diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..df12dc43f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,18 @@ +--- +name: Bug report +about: Tell us about a bug you found +title: '' +labels: bug +assignees: '' + +--- + + + +### What version are you using? + +### What did you do? + +### What happened? + +### What did you expect to see instead? diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..7efb43a2a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,14 @@ +blank_issues_enabled: true +contact_links: + - name: Frontend Project + url: https://github.com/stellar/stellar-disbursement-platform-frontend + about: The frontend project for this application. + - name: Stellar Laboratory + url: https://laboratory.stellar.org/#?network=test + about: The best place to experiment with the Stellar network. + - name: Docker Images + url: https://hub.docker.com/r/stellar/stellar-disbursement-platform-backend + about: Where to check the available Docker images that have been published. + - name: Stellar Ecosystem Proposals (SEPs) + url: https://github.com/stellar/stellar-protocol + about: The SEPs implemented in this project are defined here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..673ebab1a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,16 @@ +--- +name: Feature request +about: Tell us what you'd like to see +title: 'Feature Request: ' +labels: '' +assignees: '' + +--- + + + +### What problem does your feature solve? + +### What would you like to see? + +### What alternatives are there? diff --git a/.github/ISSUE_TEMPLATE/release_a_new_version.md b/.github/ISSUE_TEMPLATE/release_a_new_version.md new file mode 100644 index 000000000..be559ec60 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/release_a_new_version.md @@ -0,0 +1,44 @@ +--- +name: Release a New Version! +about: Prepare a release to be launched +title: '' +labels: release +--- + + + +## Release Checklist + +> Attention: the examples below use the version `x.y.z` but you should update them to use the version you're releasing. + +### Git Preparation + +- [ ] Decide on a version number based on the current version number and the common rules defined in [Semantic Versioning](https://semver.org). E.g. `x.y.z`. +- [ ] Update this ticket name to reflect the new version number, following the pattern "Release `x.y.z`". +- [ ] Cut a branch for the new release out of the `develop` branch, following the gitflow naming pattern `release/x.y.z`. + +### Code Preparation + +- [ ] Update the code to use this version number. + - [ ] Update `version` and `appVersion` in [helmchart/sdp/Chart.yaml]. + - [ ] Update the constant `Version` in [main.go] +- [ ] Update the [CHANGELOG.md] file with the new version number and release notes. +- [ ] Update the version and optionally the status in the header of the [README.md] file. +- [ ] Run tests and linting, and make sure the version running in the nain branch is working end-to-end. At least the minimal end-to-end manual tests is mandatory. +- [ ] 🚨 DO NOT RELEASE before holidays or weekends! Mondays and Tuesdays are preferred. + +### Merging the Branches + +- [ ] When the team is confident the release is stable, you'll need to create two pull requests: + - [ ] `release/x.y.z -> main`: 🚨 Do not squash-and-merge! This PR should be merged with a merge commit. + - [ ] `release/x.y.z -> develop`: this should be merged after the `main` branch is merged. 🚨 Do not squash-and-merge! This PR should be merged with a merge commit. + +### Publishing the Release + +- [ ] After the release branch is merged to `main`, create a new release on GitHub with the name `x.y.z` and the use the same changes from the [CHANGELOG.md] file. + - [ ] The release should automatically publish a new version of the docker image to Docker Hub. Double check if that happened. + +[main.go]: ../../main.go +[README.md]: ../../README.md +[helmchart/sdp/Chart.yaml]: ../../helmchart/sdp/Chart.yaml +[CHANGELOG.md]: ../../CHANGELOG.md diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..5b0ddf53a --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,48 @@ +### What + +[TODO: Short statement about what is changing.] + +### Why + +[TODO: Why this change is being made. Include any context required to understand the why.] + +### Known limitations + +[TODO or N/A] + +### Checklist + +#### PR Structure + +* [ ] This PR has reasonably narrow scope (if not, break it down into smaller PRs). +* [ ] This PR does not mix refactoring changes with feature changes (split into two PRs otherwise). +* [ ] This PR's title starts with the name of the package, area, or subject affected by the change. + +#### Thoroughness + +* [ ] This PR adds tests for the new functionality or fixes. +* [ ] This PR contains the link to the Jira ticket it addresses. + +#### Configs and Secrets + +* [ ] No new CONFIG variables are required -OR- the new required ones were added to the helmchart's [`values.yaml`] file. +* [ ] No new CONFIG variables are required -OR- the new required ones were added to the deployments ([`pr-preview`], [`dev`], [`demo`], `prd`). +* [ ] No new SECRETS variables are required -OR- the new required ones were mentioned in the helmchart's [`values.yaml`] file. +* [ ] No new SECRETS variables are required -OR- the new required ones were added to the deployments ([`pr-preview secrets`], [`dev secrets`], [`demo secrets`], `prd secrets`). + +#### Release + +* [ ] This is not a breaking change. +* [ ] **This is ready for production.**. If your PR is not ready for production, please consider opening additional complementary PRs using this one as the base. Only merge this into `develop` or `main` after it's ready for production! + +#### Deployment + +* [ ] Does the deployment work after merging? + +[`values.yaml`]: ../helmchart/sdp/values.yaml +[`pr-preview`]: https://github.com/stellar/kube/blob/d3e4f5dd8aa4c13b45a31a5a937f3e98841171a7/kube001-dev/namespaces/common-previews/stellar-disbursement-platform/backend-helm-values +[`dev`]: https://github.com/stellar/kube/blob/d3e4f5dd8aa4c13b45a31a5a937f3e98841171a7/kube001-dev/namespaces/stellar-disbursement-platform/backend-helm-values +[`demo`]: https://github.com/stellar/kube/blob/d3e4f5dd8aa4c13b45a31a5a937f3e98841171a7/kube001-dev/namespaces/stellar-disbursement-platform/demo/demo-backend-helm-values +[`pr-preview secrets`]: https://github.com/stellar/kube/blob/d3e4f5dd8aa4c13b45a31a5a937f3e98841171a7/kube001-dev/namespaces/common-previews/externalsecrets-common-previews.yaml#L241-L346 +[`dev secrets`]: https://github.com/stellar/kube/blob/d3e4f5dd8aa4c13b45a31a5a937f3e98841171a7/kube001-dev/namespaces/stellar-disbursement-platform/stellar-disbursement-platform-externalsecrets.yaml +[`demo secrets`]: https://github.com/stellar/kube/blob/d3e4f5dd8aa4c13b45a31a5a937f3e98841171a7/kube001-dev/namespaces/stellar-disbursement-platform/demo/demo-sdp-externalsecrets.yaml diff --git a/.github/workflows/anchor_platform_integration_check.yml b/.github/workflows/anchor_platform_integration_check.yml new file mode 100644 index 000000000..2890220ee --- /dev/null +++ b/.github/workflows/anchor_platform_integration_check.yml @@ -0,0 +1,56 @@ +name: SDP<>AnchorPlatform Integration + +on: + push: + branches: + - main + - develop + - "release/**" + - "releases/**" + - "hotfix/**" + pull_request: + workflow_call: # allows this workflow to be called from another workflow + +jobs: + anchor-integration: + runs-on: ubuntu-latest + environment: "Anchor Integration Tests" + env: + DISTRIBUTION_PUBLIC_KEY: ${{ secrets.DISTRIBUTION_PUBLIC_KEY }} + DISTRIBUTION_SEED: ${{ secrets.DISTRIBUTION_SEED }} + SEP10_SIGNING_PUBLIC_KEY: ${{ secrets.SEP10_SIGNING_PUBLIC_KEY }} + SEP10_SIGNING_PRIVATE_KEY: ${{ secrets.SEP10_SIGNING_PRIVATE_KEY }} + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Run Docker Compose for SDP and Anchor Platform + working-directory: dev + run: docker-compose -f docker-compose-sdp-anchor.yml down && docker-compose -f docker-compose-sdp-anchor.yml up --build -d + + - name: Install curl + run: sudo apt-get update && sudo apt-get install -y curl + + - name: Wait for localhost:8080/health + timeout-minutes: 5 + run: | + until curl --output /dev/null --silent --head --fail http://localhost:8080/health; do + echo 'Waiting for anchor-platform to be up and running...' + sleep 15 + done + echo 'Anchor-platform is up and running.' + + - name: Install NodeJs + uses: actions/setup-node@v2 + with: + node-version: 14 + + - name: Anchor Validation Tests (@stellar/anchor-tests) + run: | + npm install -g @stellar/anchor-tests + stellar-anchor-tests --home-domain http://localhost:8080 --seps 1 10 + + - name: Docker logs + if: always() + working-directory: dev + run: docker-compose -f docker-compose-sdp-anchor.yml logs && docker-compose -f docker-compose-sdp-anchor.yml down diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..9be938430 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,111 @@ +name: Go + +on: + push: + branches: + - main + - develop + - "release/**" + - "releases/**" + - "hotfix/**" + pull_request: + workflow_call: # allows this workflow to be called from another workflow + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: golangci-lint + uses: golangci/golangci-lint-action@08e2f20817b15149a52b5b3ebe7de50aff2ba8c5 # version v3.4.0 + with: + version: v1.52.2 # this is the golangci-lint version + args: --timeout 5m0s + + - name: Run ./gomod.sh + run: ./gomod.sh + + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: Build Project + run: go build ./... + + test: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:12-alpine + env: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + PGHOST: localhost + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + env: + PGHOST: localhost + PGPORT: 5432 + PGUSER: postgres + PGPASSWORD: postgres + PGDATABASE: postgres + DATABASE_URL: postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Setup Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: Run tests + run: go test -race -coverpkg=./... -coverprofile=c.out ./... + + - name: Validate Test Coverage Threshold + env: + TESTCOVERAGE_THRESHOLD: 83 # percentage + run: | + echo "Quality Gate: Checking if test coverage is above threshold..." + echo "Threshold: $TESTCOVERAGE_THRESHOLD%" + totalCoverage=`./scripts/exclude_from_coverage.sh && go tool cover -func=c.out | grep total: | grep -Eo '[0-9]+\.[0-9]+'` + echo "Test Coverage: $totalCoverage%" + echo "-------------------------" + if (( $(echo "$totalCoverage $TESTCOVERAGE_THRESHOLD" | awk '{print ($1 >= $2)}') )); then + echo " $totalCoverage% > $TESTCOVERAGE_THRESHOLD%" + echo "Current test coverage is above threshold πŸŽ‰πŸŽ‰πŸŽ‰! Please keep up the good work!" + else + echo " $totalCoverage% < $TESTCOVERAGE_THRESHOLD%" + echo "🚨 Current test coverage is below threshold 😱! Please add more unit tests or adjust threshold to a lower value." + echo "Failed 😭" + exit 1 + fi + + complete: + if: always() + needs: [check, build, test] + runs-on: ubuntu-latest + steps: + - if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') + run: exit 1 diff --git a/.github/workflows/docker_image_public_release.yml b/.github/workflows/docker_image_public_release.yml new file mode 100644 index 000000000..999568c25 --- /dev/null +++ b/.github/workflows/docker_image_public_release.yml @@ -0,0 +1,107 @@ +# This workflow publishes a new docker image to 'https://hub.docker.com/r/stellar/stellar-disbursement-platform-backend' +# when a new release is created or when we merge something to the develop branch. +name: Docker Image Public Release + +on: + release: + types: + - published + push: + branches: + - develop + +jobs: + tests: + uses: ./.github/workflows/ci.yml # execute the callable ci.yml + secrets: inherit # pass all secrets + + anchor_platform_integration_check: + uses: ./.github/workflows/anchor_platform_integration_check.yml # execute the callable anchor_platform_integration_check.yml + needs: + - tests + secrets: inherit # pass all secrets + + e2e_integration_test: + uses: ./.github/workflows/e2e_integration_test.yml # execute the callable e2e_integration_test.yml + needs: + - tests + secrets: inherit # pass all secrets + + build_and_push_docker_image_on_release: + if: github.event_name == 'release' + name: Push to DockerHub (release prd) # stellar/stellar-disbursement-platform-backend:{VERSION} + runs-on: ubuntu-latest + needs: + - tests + - anchor_platform_integration_check + - e2e_integration_test + steps: + - name: Check if tag is not empty + run: | + if [[ -z "${{ github.event.release.tag_name }}" ]]; then + echo "Release tag name cannot be empty." + exit 1 + fi + + - uses: actions/checkout@v3 + + - name: Login to DockerHub + uses: docker/login-action@v2.2.0 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push to DockerHub (release prd) + uses: docker/build-push-action@v4.1.1 + with: + push: true + build-args: | + GIT_COMMIT=${{ github.event.release.tag_name }} + tags: stellar/stellar-disbursement-platform-backend:${{ github.event.release.tag_name }},stellar/stellar-disbursement-platform-backend:latest + file: Dockerfile + + build_and_push_docker_image_on_dev_push: + if: github.event_name == 'push' && github.ref == 'refs/heads/develop' + name: Push to DockerHub (release develop branch) # stellar/stellar-disbursement-platform-backend:edge-{DATE}-{SHA} + runs-on: ubuntu-latest + needs: + - tests + - anchor_platform_integration_check + - e2e_integration_test + steps: + - uses: actions/checkout@v3 + + - name: Login to DockerHub + uses: docker/login-action@v2.2.0 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Get current date + id: get_date + run: echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT + + - name: Get SHA + shell: bash + id: get_sha + run: echo "SHA=$(git rev-parse --short ${{ github.sha }} )" >> $GITHUB_OUTPUT + + - name: Build and push to DockerHub (develop branch) + uses: docker/build-push-action@v4.1.1 + with: + push: true + build-args: | + GIT_COMMIT=${{ steps.get_sha.outputs.SHA }} + tags: stellar/stellar-disbursement-platform-backend:edge,stellar/stellar-disbursement-platform-backend:edge-${{ steps.get_date.outputs.DATE }}-${{ steps.get_sha.outputs.SHA }} + file: Dockerfile + + complete: + if: always() + needs: + - build_and_push_docker_image_on_release + - build_and_push_docker_image_on_dev_push + runs-on: ubuntu-latest + steps: + - if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') + run: exit 1 + # TODO: figure out which job failed and print the logs diff --git a/.github/workflows/e2e_integration_test.yml b/.github/workflows/e2e_integration_test.yml new file mode 100644 index 000000000..3bd0e70df --- /dev/null +++ b/.github/workflows/e2e_integration_test.yml @@ -0,0 +1,70 @@ +name: E2E integration test + +on: + push: + branches: + - main + - develop + - "release/**" + - "releases/**" + - "hotfix/**" + pull_request: + workflow_call: # allows this workflow to be called from another workflow + +env: + USER_EMAIL: "sdp_user@stellar.org" + USER_PASSWORD: "mockPassword123!" + +jobs: + e2e-integration-test: + runs-on: ubuntu-latest + environment: "Receiver Registration - E2E Integration Tests" + env: + DISTRIBUTION_PUBLIC_KEY: ${{ secrets.DISTRIBUTION_PUBLIC_KEY }} + DISTRIBUTION_SEED: ${{ secrets.DISTRIBUTION_SEED }} + SEP10_SIGNING_PUBLIC_KEY: ${{ secrets.SEP10_SIGNING_PUBLIC_KEY }} + SEP10_SIGNING_PRIVATE_KEY: ${{ secrets.SEP10_SIGNING_PRIVATE_KEY }} + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Cleanup data + working-directory: internal/integrationtests + run: docker-compose -f docker-compose-e2e-tests.yml down -v + + - name: Run Docker Compose for SDP, Anchor Platform and TSS + working-directory: internal/integrationtests + run: docker-compose -f docker-compose-e2e-tests.yml up --build -V -d + + - name: Install curl + run: sudo apt-get update && sudo apt-get install -y curl + + - name: Create authenticated user + run: | + docker exec e2e-sdp-api bash -c "echo '$USER_PASSWORD' | ./stellar-disbursement-platform auth add-user '$USER_EMAIL' joe yabuki --password --owner --roles owner" + + - name: Create integration test data + run: | + docker exec e2e-sdp-api bash -c "./stellar-disbursement-platform integration-tests create-data" + + - name: Restart anchor platform + run: | + docker restart e2e-anchor-platform + + - name: Wait for anchor platform localhost:8080/health + timeout-minutes: 5 + run: | + until curl --output /dev/null --silent --head --fail http://localhost:8080/health; do + echo 'Waiting for anchor-platform to be up and running...' + sleep 15 + done + echo 'Anchor-platform is up and running.' + + - name: Start integration test command + run: | + docker exec e2e-sdp-api bash -c "./stellar-disbursement-platform integration-tests start" + + - name: Docker logs + if: always() + working-directory: internal/integrationtests + run: docker-compose -f docker-compose-e2e-tests.yml logs && docker-compose -f docker-compose-e2e-tests.yml down diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..02e4ef67a --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Environment file +.env + +# subproject used for testing: +v1_compatibility/stellar-relief-backoffice-backend + +# Project binary: +stellar-disbursement-platform-backend + +# Text Editors +.vscode diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..66282e9cc --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,12 @@ +linters: + # Enable these linters in addition to the default linters that golangci-lint starts with. + enable: + - gofmt + - gofumpt + - govet + +linters-settings: + gofmt: + simplify: true + govet: + check-shadowing: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..ec19eb079 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,43 @@ +# How to contribute + +πŸ‘πŸŽ‰ First off, thanks for taking the time to contribute! πŸŽ‰πŸ‘ + +Check out the [Stellar Contribution Guide](https://github.com/stellar/.github/blob/master/CONTRIBUTING.md) that applies to all Stellar projects. + +## Style guides + +### Issues + +* Ensure the issue was not already reported by searching on GitHub under Issues. +* Issues start with: + * The functional area most affected, ex. `disbursements: fix... `. + * Or, `ci:` when changes or an issue are isolated to CI. + * Or, `doc:` when changes or an issue are isolated to non-code documentation not limited to a single package. +* Label issues with `bug` if they're clearly a bug. +* Label issues with `feature request` if they're a feature request. + +### Pull Requests + +* **Title:** PR titles start with feat, fix, refactor, ci, or doc, followed by a short description of the change. +* **Branching:** PRs must be opened against the `develop` branch. +* **Scope:** PRs must be focused and not contain unrelated commits. +* **Refactoring:** Explicitly differentiate refactoring PRs and feature PRs. Refactoring PRs don’t change functionality. They usually touch a lot more code, and are reviewed in less detail. Avoid refactoring in feature PRs. +* **Go Formatting:** Ensure your code is formatted with `gofmt`. +* **Tests:** Ensure your change is covered by tests. If you're adding a new feature or fixing a bug, you must add tests. If you're refactoring, you should add tests if possible. +* **Documentation:** Update README.md or other relevant documentation pages if necessary. For exported functions, types, and constants, make sure to add a doc comment conforming to [Effective Go](https://golang.org/doc/effective_go.html#commentary). +* **Best Practices:** * Follow [Effective Go](https://golang.org/doc/effective_go.html) and [Go Code Review Comments](https://github.com/golang/go/wiki/CodeReviewComments). + + +### Git Commit Messages + +* Use the present tense ("Add feature" not "Added feature"). +* Use the imperative mood ("Move cursor to..." not "Moves cursor to..."). +* Start commit message with the relevant issue number, e.g., #123 Fixed bug in XYZ module. + +## Development Environment + +All SDP services can be started using the `docker-compose.yml` file in the `dev` directory. Please refer to the [README](dev/README.md) for more information. + +## Code of Conduct + +Help us keep Stellar open and inclusive. Please read and follow our [Code of Conduct](https://github.com/stellar/.github/blob/master/CODE_OF_CONDUCT.md). diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..7f7132bfb --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +# To build: +# make docker-build +# To push: +# make docker-push + +FROM golang:1.20-bullseye as build +ARG GIT_COMMIT + +WORKDIR /src/stellar-disbursement-platform +ADD go.mod go.sum ./ +RUN go mod download +ADD . ./ +RUN go build -o /bin/stellar-disbursement-platform -ldflags "-X main.GitCommit=$GIT_COMMIT" . + + +FROM ubuntu:22.04 + +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates +# ADD migrations/ /app/migrations/ +COPY --from=build /bin/stellar-disbursement-platform /app/ +EXPOSE 8001 +WORKDIR /app +ENTRYPOINT ["./stellar-disbursement-platform"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..7a4a3ea24 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..dd722264c --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +# Check if we need to prepend docker command with sudo +SUDO := $(shell docker version >/dev/null 2>&1 || echo "sudo") + +# If LABEL is not provided set default value +LABEL ?= $(shell git rev-parse --short HEAD)$(and $(shell git status -s),-dirty-$(shell id -u -n)) +# If TAG is not provided set default value +TAG ?= stellar/stellar-disbursement-platform:$(LABEL) +# https://github.com/opencontainers/image-spec/blob/master/annotations.md +BUILD_DATE := $(shell date -u +%FT%TZ) + +docker-build: + $(SUDO) docker build --pull --label org.opencontainers.image.created="$(BUILD_DATE)" -t $(TAG) --build-arg GIT_COMMIT=$(LABEL) . + +docker-push: + $(SUDO) docker push $(TAG) + +go-install: + go build -o $(GOPATH)/bin/stellar-disbursement-platform -ldflags "-X main.GitCommit=$(LABEL)" . \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 000000000..8dbf1b96e --- /dev/null +++ b/README.md @@ -0,0 +1,302 @@ +# Stellar Disbursement Platform Backend + +## Table of Contents + +- [Introduction](#introduction) +- [Install](#install) +- [Quick Start](#quick-start) +- [Architecture](#architecture) + - [Core](#core) + - [Transaction Submission Service](#transaction-submission-service) + - [Database](#database) +- [SDP Operators](#sdp-operators) + - [Supporting New Wallets](#supporting-new-wallets) +- [Wallets](#wallets) + - [Recipient Registration Experience](#recipient-registration-experience) + - [Deferred Deep Links](#deferred-deep-links) + - [Wallet Registration Procedure](#wallet-registration-procedure) +- [Contributors](#contributors) + - [State Transitions](#state-transitions) + +## Introduction + +The Stellar Disbursement Platform (SDP) enables organizations to disburse bulk payments to recipients using Stellar. + +Throughout this documentation, we'll define "users" as members of the organization using the SDP to make payments, while defining "recipients" as those receiving payments. + +## Install + +Install golang and make sure `$GOPATH/bin` is in your `$PATH`. Then run the following. + +``` sh +git clone git@github.com:stellar/stellar-disbursement-platform-backend.git +cd stellar-disbursement-platform-backend +make go-install +stellar-disbursement-platform --help +``` + +## Quick Start + +To quickly test the SDP using preconfigured values, see the [Quick Start Guide](./dev/README.md). + +## Architecture + +![high_level_architecture](./docs/images/high_level_architecture.png) + +The [SDP Dashboard][sdp-dashboard] and [Anchor Platform](https://github.com/stellar/java-stellar-anchor-sdk) components are separate projects that must be installed and configured alongside the services included in this project. + +In a future iteration of this project, the Transaction Submission Service (TSS) will also be moved to its own repository to be used as an independent service. At that point, this project will include the services contained in the Core module shown in the diagram above. + +### Core + +The SDP Core service include several components started using a single command. + +```sh +stellar-disbursement-platform serve --help +``` + +#### Dashboard API + +The Dashboard API is the component responsible for enabling clients to interact with the SDP. The primary client is the [SDP Dashboard][sdp-dashboard], but other clients can use the API as well. + +##### Metrics + +The Dashboard API component is also responsible for exporting system and application metrics. We only have support for `Prometheus` at the moment, but we can add new monitors clients in the future. + +#### Message Service + +The Message Service sends messages to users and recipients for the following reasons: + +- Informing recipients they have an incoming disbursement and need to register +- Providing one-time passcodes (OTPs) to recipients +- Sending emails to users during account creation and account recovery flows + +Note that the Message Service requires that both SMS and email services are configured. For emails, AWS SES is supported. For SMS messages to recipients, Twilio is supported. AWS SNS support is not integrated yet. + +If you're using the `AWS_EMAIL` sender type, you'll need to verify the email address you're using to send emails in order to prevent it from being flagged by email firewalls. You can do that by following the instructions in [this link](https://docs.aws.amazon.com/ses/latest/dg/email-authentication-methods.html). + +#### Wallet Registration UI + +The Wallet Registration UI is also hosted by the core server, and enables recipients to confirm their phone number and other information used to verify their identity. Once recipients have registered through this UI, the Transaction Submission Server (TSS) immediately makes the payment to the recpients registered Stellar account. + +### Transaction Submission Service + +Refer to documentation [here](/internal/transactionsubmission/README.md). + +#### Core + TSS Integration + +Currently, Core and Transaction Submission Service (TSS) interact at the database layer, sharing the `submitter_transactions` table to read and write state. The interaction is as follows: + +1. Core inserts rows into the `submitter_transactions` table, queuing payments +2. The TSS polls the `submitter_transactions` table, detecting payments +3. For each payment detected, the TSS creates and submits a transaction to the Stellar network, monitoring its state until it is confirmed to have been included in a ledger or failed with a nonrecoverable error +4. Core's Dashboard API reads from the `submitter_transactions` table on demand to fetch the state of each payment + +In future iterations of the project, the Transaction Submission Service will provide an API for clients such as the SDP to use for queuing and polling the state of transactions. + +### Database + +To manage the migrations of the database, use the `db` subcommand. + +```sh +stellar-disbursement-platform db --help +``` + +Note that there is an `auth` subcommand that has its own `migrate` sub-subcommand. Operators of the SDP will need to ensure migrations for both the core and auth components are run. + +```sh +stellar-disbursement-platform db migrate up +stellar-disbursement-platform db auth migrate up +``` + +#### Core Tables + +The tables below are used to facilitate disbursements. + +![core schema](./docs/images/core_schema.png) + +The tables below are used to manage user roles and organizational information. + +![admin schema](./docs/images/admin_schema.png) + +#### TSS Tables + +The tables below are shared by the transaction submission service and core service. + +![tss schema](./docs/images/tss_schema.png) + +Note that the `submitter_transactions` table is used by the TSS and will be managed by the service when moved to its own project. + +## SDP Operators + +### Supporting New Wallets + +Adding support for new wallets involves registering a new wallet in the database and correctly populating the `deep_link_schema` column. + +Additionally, ensure the `sep_10_client_domain` column is correctly filled, matching the domain where the wallet provides a [SEP-10] authentication endpoint. 🚨 Note that this step is crucial for verifying the recipient's authentication from a trusted wallet. It's optional in the testnet environment. + +When adding a wallet, you also need to provide the wallet `name` and `homepage`. + +## Wallets + +### Recipient Registration Experience + +The recipient experience is as follows: + +1. The recipient receives an SMS message notifying them they have a payment waiting from the organization and prompts them to click a [deep link] to open or install a wallet application +1. When the recipient opens the app, the wallet immediately onboards the recipient, creates a Stellar account and trustline for them, initiates a [SEP-24] deposit transaction with the SDP, and opens the SDP's registration webpage as an overlay screen/iframe inside the app. +1. The user confirms their phone number and date of birth and is prompted to return to the wallet application +1. The user receives the payment within seconds + +### Deferred Deep Links + +Most likely, the intended recipient will not have the necessary wallet application installed on their device. For this reason, wallets should support the concept of [deferred deep linking], which enables the following flow: + +1. The recipient's initial action of clicking the deep link should redirect them to the appropriate app store to download the wallet application. +1. After installing and opening the application, the recpient should be rerouted to the wallet's typical onboarding flow. +1. Once the user has successfully onboarded, the wallet should use the information included in the deep link to kick off the [Wallet Registration Procedure](#wallet-registration-procedure). + +Deferred deep linking is a feature commonly supported by numerous mobile deep linking solutions, there are third-party services that can be used to implement this functionality, such as Singular, Branch, AppsFlyer, Adjust, and others. [Here](https://medium.com/bumble-tech/universal-links-for-android-and-ios-1ddb1e70cab0) is a post with more information on how to implement deferred deep linking. + +The registration link sent to recpients follows this format + +```url +https://?asset=&domain=&name=&signature= +``` + +- `asset`: the Stellar asset +- `domain`: the domain hosting the SDP's `stellar.toml` file +- `name`: the name of the organization sending payments +- `signature`: a signature from the SDP's [SEP-10] signing key + +> Note that the deep link is specific to each SDP, payer org, and asset. It is not specific per individual receiver. There is no risk in sharing the link with receivers who are part of the same disbursement. The link will be the same for multiple receivers and they will proove their identity as part of the [SEP-24] deposit flow. + +Below is an example of a registration link (signed) + +```url +https://vibrantapp.com/sdp-dev?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar+Test&signature=fea6c5e805a29b903835bea2f6c60069113effdf1c5cb448d4948573c65557b1d667bcd176c24a94ed9d54a1829317c74f39319076511512a3e697b4b746ae0a +``` + +In this example, the host is `https://vibrantapp.com/sdp-dev` and the signature is the result of signing the below (unsigned) url using the [SEP-10] signing key `SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5`, with the public key being `GBFDUUZ5ZYC6RAPOQLM7IYXLFHYTMCYXBGM7NIC4EE2MWOSGIYCOSN5F`: + +```url +https://vibrantapp.com/sdp-dev?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar+Test +``` + +In this example, the signature is `fea6c5e805a29b903835bea2f6c60069113effdf1c5cb448d4948573c65557b1d667bcd176c24a94ed9d54a1829317c74f39319076511512a3e697b4b746ae0a`. + +Below is a JavaScript snippet demonstrating how to verify the signature: + +```js +#!/usr/bin/env node + +const { Keypair } = require("stellar-sdk"); + +// The SDP's stellar.toml SIGNING_KEY +// +// For security, this should ideally be fetched from +// https:///.well-known/stellar.toml on demand +const keypair = Keypair.fromPublicKey( + "GBFDUUZ5ZYC6RAPOQLM7IYXLFHYTMCYXBGM7NIC4EE2MWOSGIYCOSN5F" +); +console.log("public key:", keypair.publicKey()); + +let url = + "https://aidtestnet.netlify.app/aid?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar+Test" +let signature = + "d8f5c9f0ece3118488d1546e1cb4071327a4f7f4f3efd5deefb8e92d668fca28504da8861f260c6ede26624d7a5bc244be1cf17c011e1994e3f45e2f19ea9b01" + +console.log( + "verified:", + keypair.verify( + Buffer.from(url.toString(), "utf8"), + Buffer.from(signature, "hex"), + ), +); +``` + +### Wallet Registration Procedure + +1. Confirm that the `domain` of the deep link is on the wallet's allowlist. 🚨 This is crucial for authenticating from a trusted wallet. +1. Fetch the SDP's toml file at `{domain}/.well-known/stellar.toml` and confirm the `SIGNING_KEY` variable is populated. +1. Verify that the registration link signature was made using `SIGNING_KEY` similar to the `keypairPk.verify(...)` function in the snippet above. +1. Check the `asset` from the link and confirm that the recipient user has a trustline for that asset. Create one if it doesn't exist. +1. (Optional) Use the `name` from the link to update the wallet user interface. +1. Initiate the [SEP-24] deposit flow with that asset using the `TRANSFER_SERVER_SEP0024` value from the SDP's toml file. + - This includes using [SEP-10] to authenticate the user with the SDP's server and implementing the `client_domain` check, as detailed in the [SEP-10] spec. +1. Launch the deposit flow interactive *in-app browser* within your mobile app, following the instructions in the [SEP-24] spec. + - ATTENTION: the wallet should not, in any circumstances, scrape or attempt to scrape the content from the *in-app browser* for the recipient's information. + - NOTE: It's highly recommended to use an *in-app browser* rather than a webview. +1. πŸŽ‰ Congratulations! The recipient user can now fill out the forms in the *in-app browser* and register to receive their payment πŸŽ‰. + +Additionally, the wallet should save the link and/or link attributes and associate it with the individual receiving user for these reasons: + +1. This is how the wallet will know that the user is associated with a certain org or SDP. +1. Saving the data is useful for reporting and troubleshooting, especially if the wallet needs to justify the source of funds for regulatory or tax purposes. Additionally, if the payer org wants to pay any cashout fees charged by the wallet or offramp, the wallet will need to know which users and transactions should be invoiced upstream. + +## Contributors + +This section is a work-in-progress. + +### State Transitions + +The state transitions of a disbursement, payment, message, and wallet (i.e. recipient Stellar account) are described below. + +#### Disbursements + +```mermaid +stateDiagram-v2 + [*] --> Draft:Started creating the disbursement + Draft --> [*]:User deleted\nthe draft + Draft --> Draft:File Ingestion failed\n due to wrong data + Draft --> Ready:Upload + Ready --> Started:User Started Disbursement\n in the Dashboard + Started --> Paused:Paused + Paused --> Started:Unpaused + Started --> Completed:All payments\n went through +``` + +#### Payments + +```mermaid +stateDiagram-v2 + [*] --> Draft:Upload a disbursement CSV + Draft --> [*]:Disbursement deleted + Draft --> Ready:Disbursement started + Ready --> Paused:Paused + Paused --> Ready:Unpaused + Ready --> Pending:Payment gets submitted\nif user is ready + Pending --> Success:Payment succeeds + Pending --> Failed:Payment fails + Failed --> Pending:Retry +``` + +#### Recipient Wallets + +```mermaid +stateDiagram-v2 + [*] --> Draft:Upload disbursement CSV + Draft --> [*]:disbursement deleted + Draft --> Ready: Disbursement started + Ready --> Registered: receiver signed up + Ready --> Flagged: flagged + Flagged --> Ready: unflagged + Registered --> Flagged: flagged + Flagged --> Registered: unflagged +``` + +#### Messages + +```mermaid +stateDiagram-v2 + [*] --> Pending: Message is queued + Pending --> Success:Message sender\nAPI succeeds + Pending --> Failed:Message sender\nAPI fails + Failed --> Pending:Retry +``` + +[deferred deep linking]: https://en.wikipedia.org/wiki/Mobile_deep_linking#Deferred_deep_linking +[deep link]: https://en.wikipedia.org/wiki/Mobile_deep_linking +[SEP-10]: https://stellar.org/protocol/sep-10 +[SEP-24]: https://stellar.org/protocol/sep-24 +[sdp-dashboard]: https://github.com/stellar/stellar-disbursement-platform-frontend diff --git a/cmd/auth.go b/cmd/auth.go new file mode 100644 index 000000000..921f9aa83 --- /dev/null +++ b/cmd/auth.go @@ -0,0 +1,147 @@ +package cmd + +import ( + "fmt" + "go/types" + "net/url" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + di "github.com/stellar/stellar-disbursement-platform-backend/internal/dependencyinjection" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/cli" +) + +type AuthCommand struct{} + +func (a *AuthCommand) Command() *cobra.Command { + var uiBaseURL string + messengerOptions := message.MessengerOptions{} + + authCmdConfigOpts := config.ConfigOptions{ + { + Name: "sdp-ui-base-url", + Usage: "The SDP UI Base URL used to send the invitation link when a new user is created.", + OptType: types.String, + ConfigKey: &uiBaseURL, + FlagDefault: "http://localhost:3000", + CustomSetValue: cmdUtils.SetConfigOptionURLString, + Required: true, + }, + { + Name: "email-sender-type", + Usage: fmt.Sprintf("The messenger type used to send invitations to new dashboard users. Options: %+v", message.MessengerType("").ValidEmailTypes()), + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionMessengerType, + ConfigKey: &messengerOptions.MessengerType, + FlagDefault: string(message.MessengerTypeDryRun), + Required: true, + }, + } + authCmdConfigOpts = append(authCmdConfigOpts, cmdUtils.TwilioConfigOptions(&messengerOptions)...) + authCmdConfigOpts = append(authCmdConfigOpts, cmdUtils.AWSConfigOptions(&messengerOptions)...) + + var emailMessengerClient message.MessengerClient + + // Auth Module sub-commands + availableRoles := data.FromUserRoleArrayToStringArray(data.GetAllRoles()) + addUserSubcommand := cli.AddUserCmd(dbConfigOptionFlagName, cli.NewDefaultPasswordPrompt(), availableRoles) + + authCmd := &cobra.Command{ + Use: "auth", + Short: "Stellar Auth helpers", + Example: "auth ", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + authCmdConfigOpts.Require() + err := authCmdConfigOpts.SetValues() + if err != nil { + log.Fatalf("error setting values of config options: %s", err.Error()) + } + + if cmd.Name() == addUserSubcommand.Name() && !viper.GetBool("password") { + emailOptions := di.EmailClientOptions{EmailType: messengerOptions.MessengerType, MessengerOptions: &messengerOptions} + emailMessengerClient, err = di.NewEmailClient(emailOptions) + if err != nil { + log.Ctx(ctx).Fatalf("error creating dashboard user client: %s", err.Error()) + } + } + }, + Run: func(cmd *cobra.Command, args []string) { + if err := cmd.Help(); err != nil { + log.Fatalf("Error calling auth command: %s", err.Error()) + } + }, + PersistentPostRun: func(cmd *cobra.Command, args []string) { + // If the user was registered without set the password. We should + // send the invitation email. + if cmd.Name() == addUserSubcommand.Name() && !viper.GetBool("password") { + ctx := cmd.Context() + + email, firstName := args[0], args[1] + + // We don't need to validate the content since it was already validated + // in the stellar-auth + role := viper.GetString("roles") + + forgotPasswordLink, err := url.JoinPath(uiBaseURL, "forgot-password") + if err != nil { + log.Ctx(ctx).Fatalf("error getting forgot password link: %s", err.Error()) + } + + dbConnectionPool, err := db.OpenDBConnectionPool(globalOptions.databaseURL) + if err != nil { + log.Ctx(ctx).Fatalf("error getting database connection: %s", err.Error()) + } + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + if err != nil { + log.Ctx(ctx).Fatalf("error getting models: %s", err.Error()) + } + + organization, err := models.Organizations.Get(ctx) + if err != nil { + log.Ctx(ctx).Fatalf("error getting organization data: %s", err.Error()) + } + + invitationData := htmltemplate.InvitationMessageTemplate{ + FirstName: firstName, + Role: role, + ForgotPasswordLink: forgotPasswordLink, + OrganizationName: organization.Name, + } + + msgBody, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(invitationData) + if err != nil { + log.Ctx(ctx).Fatalf("error executing invitation message template: %s", err.Error()) + } + + err = emailMessengerClient.SendMessage(message.Message{ + ToEmail: email, + Title: "Welcome to Stellar Disbursement Platform", + Message: msgBody, + }) + if err != nil { + log.Ctx(ctx).Fatalf("error sending invitation message: %s", err.Error()) + } + } + }, + } + + if err := authCmdConfigOpts.Init(authCmd); err != nil { + log.Fatalf("error initializing authCmd config options: %s", err.Error()) + } + + authCmd.AddCommand(addUserSubcommand) + + return authCmd +} diff --git a/cmd/auth_test.go b/cmd/auth_test.go new file mode 100644 index 000000000..ac68219f9 --- /dev/null +++ b/cmd/auth_test.go @@ -0,0 +1,187 @@ +package cmd + +import ( + "io" + "os" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_persistentPostRun(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + t.Setenv("DATABASE_URL", dbt.DSN) + t.Setenv("EMAIL_SENDER_TYPE", "DRY_RUN") + + addUserCmdMock := &cobra.Command{ + Use: "add-user [--password] [--owner]", + Args: cobra.ExactArgs(3), + Run: func(cmd *cobra.Command, args []string) { + assert.Equal(t, []string{"email@email.com", "First", "Last"}, args) + }, + } + + addUserCmdMock.PersistentFlags().String("roles", "", "") + err := viper.BindPFlag("roles", addUserCmdMock.PersistentFlags().Lookup("roles")) + require.NoError(t, err) + + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + rootCmd.SetArgs([]string{"auth", "add-user", "email@email.com", "First", "Last", "--roles", "developer"}) + + for _, cmd := range rootCmd.Commands() { + if cmd.Name() == "auth" { + for _, authCmd := range cmd.Commands() { + if authCmd.Name() == "add-user" { + cmd.RemoveCommand(authCmd) + cmd.AddCommand(addUserCmdMock) + break + } + } + break + } + } + + stdOut := os.Stdout + + r, w, err := os.Pipe() + require.NoError(t, err) + + os.Stdout = w + + err = rootCmd.Execute() + require.NoError(t, err) + + expectContains := `------------------------------------------------------------------------------- +Recipient: email@email.com +Subject: Welcome to Stellar Disbursement Platform +Content: + + + + Welcome to Stellar Disbursement Platform + + + +

Hello, First!

+

You have been added to your organization's Stellar Disbursement Platform as a developer. Please click the link below to set up your password and let your organization administrator know if you have any questions.

+

+ Set up my password +

+

Best regards,

+

The MyCustomAid Team

+ + + +------------------------------------------------------------------------------- +` + + w.Close() + os.Stdout = stdOut + + buf := new(strings.Builder) + _, err = io.Copy(buf, r) + require.NoError(t, err) + + assert.Contains(t, buf.String(), expectContains) + + // Set another SDP UI base URL + rootCmd.SetArgs([]string{"auth", "add-user", "email@email.com", "First", "Last", "--roles", "developer", "--sdp-ui-base-url", "https://sdp-ui.org"}) + + stdOut = os.Stdout + + r, w, err = os.Pipe() + require.NoError(t, err) + + os.Stdout = w + + err = rootCmd.Execute() + require.NoError(t, err) + + expectContains = `------------------------------------------------------------------------------- +Recipient: email@email.com +Subject: Welcome to Stellar Disbursement Platform +Content: + + + + Welcome to Stellar Disbursement Platform + + + +

Hello, First!

+

You have been added to your organization's Stellar Disbursement Platform as a developer. Please click the link below to set up your password and let your organization administrator know if you have any questions.

+

+ Set up my password +

+

Best regards,

+

The MyCustomAid Team

+ + + +------------------------------------------------------------------------------- +` + + w.Close() + os.Stdout = stdOut + + buf.Reset() + _, err = io.Copy(buf, r) + require.NoError(t, err) + + assert.Contains(t, buf.String(), expectContains) +} diff --git a/cmd/channel_accounts.go b/cmd/channel_accounts.go new file mode 100644 index 000000000..88ceea0bb --- /dev/null +++ b/cmd/channel_accounts.go @@ -0,0 +1,364 @@ +package cmd + +import ( + "go/types" + + "github.com/spf13/cobra" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + "github.com/stellar/go/txnbuild" + + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + di "github.com/stellar/stellar-disbursement-platform-backend/internal/dependencyinjection" + txSubSvc "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/services" +) + +type ChannelAccountsCommand struct { + Service txSubSvc.ChannelAccountsServiceInterface + CrashTrackerClient crashtracker.CrashTrackerClient +} + +func (c *ChannelAccountsCommand) Command() *cobra.Command { + svcOpts := &txSubSvc.ChannelAccountServiceOptions{} + crashTrackerOptions := crashtracker.CrashTrackerOptions{} + + configOpts := config.ConfigOptions{ + { + Name: "horizon-url", + Usage: `Horizon URL"`, + OptType: types.String, + ConfigKey: &svcOpts.HorizonUrl, + FlagDefault: horizonclient.DefaultTestNetClient.HorizonURL, + Required: true, + }, + { + Name: "crash-tracker-type", + Usage: `Crash tracker type. Options: "SENTRY", "DRY_RUN"`, + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionCrashTrackerType, + ConfigKey: &crashTrackerOptions.CrashTrackerType, + FlagDefault: "DRY_RUN", + Required: true, + }, + } + channelAccountsCmd := &cobra.Command{ + Use: "channel-accounts", + Short: "Channel accounts related commands", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + ctx := cmd.Context() + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + + // Inject server dependencies + svcOpts.DatabaseDSN = globalOptions.databaseURL + svcOpts.NetworkPassphrase = globalOptions.networkPassphrase + + c.Service, err = txSubSvc.NewChannelAccountService(*svcOpts) + if err != nil { + log.Ctx(ctx).Fatalf("Error creating channel account service: %s", err.Error()) + } + + // Inject crash tracker options dependencies + globalOptions.populateCrashTrackerOptions(&crashTrackerOptions) + + // Setup default Crash Tracker client + crashTrackerClient, err := di.NewCrashTracker(ctx, crashTrackerOptions) + if err != nil { + log.Ctx(ctx).Fatalf("Error creating crash tracker client: %s", err.Error()) + } + c.CrashTrackerClient = crashTrackerClient + }, + } + err := configOpts.Init(channelAccountsCmd) + if err != nil { + log.Fatalf("Error initializing channelAccountsCmd config option: %s", err.Error()) + } + + createCmd := c.CreateCommand(svcOpts) + deleteCmd := c.DeleteCommand(svcOpts) + ensureCmd := c.EnsureCommand(svcOpts) + verifyCmd := c.VerifyCommand(svcOpts) + viewCmd := c.ViewCommand() + channelAccountsCmd.AddCommand(createCmd, deleteCmd, ensureCmd, verifyCmd, viewCmd) + + return channelAccountsCmd +} + +func (c *ChannelAccountsCommand) CreateCommand(toolOpts *txSubSvc.ChannelAccountServiceOptions) *cobra.Command { + configOpts := config.ConfigOptions{ + { + Name: "distribution-seed", + Usage: "The private key of the Stellar account that will be used to sponsor the channel accounts", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPrivateKey, + ConfigKey: &toolOpts.RootSeed, + Required: true, + }, + { + Name: "num-channel-accounts-create", + Usage: "The desired number of channel accounts to be created", + OptType: types.Int, + ConfigKey: &toolOpts.NumChannelAccounts, + FlagDefault: 1, + Required: true, + }, + { + Name: "max-base-fee", + Usage: "The max base fee for submitting a stellar transaction", + OptType: types.Int, + ConfigKey: &toolOpts.MaxBaseFee, + FlagDefault: txnbuild.MinBaseFee, + Required: true, + }, + { + Name: "encrypt-key", + Usage: "Whether or not to encrypt the private key for storage", + OptType: types.Bool, + ConfigKey: &toolOpts.EncryptKey, + FlagDefault: true, + Required: true, + }, + // { // TODO - actually use this - not needed for SDP's TSS + // Name: "output", + // Usage: "where to output the channel accounts (database or csv file)", + // OptType: types.String, + // CustomSetValue: ?, + // ConfigKey: &submitterOpts.DistributionPublicKey, + // FlagDefault: "database", + // Required: false, + // }, + } + createCmd := &cobra.Command{ + Use: "create", + Short: "Create channel accounts", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + + // entrypoint into the main logic for creating channel accounts + if err := c.Service.CreateChannelAccountsOnChain(ctx, *toolOpts); err != nil { + c.CrashTrackerClient.LogAndReportErrors(ctx, err, "Cmd channel-accounts create crash") + log.Ctx(ctx).Fatalf("Error creating channel accounts: %s", err.Error()) + } + }, + } + err := configOpts.Init(createCmd) + if err != nil { + log.Fatalf("Error initializing createCmd: %s", err.Error()) + } + + return createCmd +} + +func (c *ChannelAccountsCommand) VerifyCommand(toolOpts *txSubSvc.ChannelAccountServiceOptions) *cobra.Command { + configOpts := config.ConfigOptions{ + { + Name: "delete-invalid-accounts", + Usage: "Delete channel accounts from storage that are verified to be invalid on the network", + OptType: types.Bool, + ConfigKey: &toolOpts.DeleteInvalidAcccounts, + FlagDefault: false, + Required: false, + }, + } + + verifyCmd := &cobra.Command{ + Use: "verify", + Short: "Verify the existence of all channel accounts in the database on the Stellar newtwork", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + if err := c.Service.VerifyChannelAccounts(ctx, *toolOpts); err != nil { + c.CrashTrackerClient.LogAndReportErrors(ctx, err, "Cmd channel-accounts verify crash") + log.Ctx(ctx).Fatalf("Error verifying channel accounts: %s", err.Error()) + } + }, + } + err := configOpts.Init(verifyCmd) + if err != nil { + log.Fatalf("Error initializing verifyCmd: %s", err.Error()) + } + + return verifyCmd +} + +func (c *ChannelAccountsCommand) EnsureCommand(toolOpts *txSubSvc.ChannelAccountServiceOptions) *cobra.Command { + configOpts := config.ConfigOptions{ + { + Name: "distribution-seed", + Usage: "The private key of the Stellar account used to sponsor existing channel accounts", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPrivateKey, + ConfigKey: &toolOpts.RootSeed, + Required: true, + }, + { + Name: "num-channel-accounts-ensure", + Usage: "The desired number of channel accounts to manage", + OptType: types.Int, + ConfigKey: &toolOpts.NumChannelAccounts, + FlagDefault: 1, + Required: true, + }, + { + Name: "max-base-fee", + Usage: "The max base fee for submitting a stellar transaction", + OptType: types.Int, + ConfigKey: &toolOpts.MaxBaseFee, + FlagDefault: txnbuild.MinBaseFee, + Required: true, + }, + { + Name: "encrypt-key", + Usage: "Whether or not to encrypt the private key for storage", + OptType: types.Bool, + ConfigKey: &toolOpts.EncryptKey, + FlagDefault: true, + Required: true, + }, + } + + ensureCmd := &cobra.Command{ + Use: "ensure", + Short: "Ensure we are managing exactly the number of channel accounts " + + "equal to some specified count by dynamically increasing or decreasing the number of managed " + + "channel accounts in storage and onchain", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + if err := c.Service.EnsureChannelAccountsCount(ctx, *toolOpts); err != nil { + c.CrashTrackerClient.LogAndReportErrors(ctx, err, "Cmd channel-accounts ensure crash") + log.Ctx(ctx).Fatalf("Error ensuring count for channel accounts: %s", err.Error()) + } + }, + } + + err := configOpts.Init(ensureCmd) + if err != nil { + log.Fatalf("Error initializing ensureCmd: %s", err.Error()) + } + + return ensureCmd +} + +func (c *ChannelAccountsCommand) DeleteCommand(toolOpts *txSubSvc.ChannelAccountServiceOptions) *cobra.Command { + configOpts := config.ConfigOptions{ + { + Name: "distribution-seed", + Usage: "The private key of the Stellar account used to sponsor the channel account specified", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPrivateKey, + ConfigKey: &toolOpts.RootSeed, + Required: true, + }, + { + Name: "channel-account-id", + Usage: "The ID of the channel account to delete", + OptType: types.String, + ConfigKey: &toolOpts.ChannelAccountID, + Required: false, + }, + { + Name: "delete-all-accounts", + Usage: "Delete all managed channel accoounts in the database and on the network", + OptType: types.Bool, + ConfigKey: &toolOpts.DeleteAllAccounts, + FlagDefault: false, + Required: false, + }, + { + Name: "max-base-fee", + Usage: "The max base fee for submitting a stellar transaction", + OptType: types.Int, + ConfigKey: &toolOpts.MaxBaseFee, + FlagDefault: txnbuild.MinBaseFee, + Required: true, + }, + } + + deleteCmd := &cobra.Command{ + Use: "delete", + Short: "Delete a specified channel account from storage and on the network", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + if err := c.Service.DeleteChannelAccount(ctx, *toolOpts); err != nil { + c.CrashTrackerClient.LogAndReportErrors(ctx, err, "Cmd channel-accounts delete crash") + log.Ctx(ctx).Fatalf("Error deleting channel account: %s", err.Error()) + } + }, + } + + err := configOpts.Init(deleteCmd) + if err != nil { + log.Fatalf("Error initializing deleteCmd: %s", err.Error()) + } + + deleteCmd.MarkFlagsMutuallyExclusive("channel-account-id", "delete-all-accounts") + + return deleteCmd +} + +func (c *ChannelAccountsCommand) ViewCommand() *cobra.Command { + viewCmd := &cobra.Command{ + Use: "view", + Short: "View all channel accounts currently managed in the database", + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + err := c.Service.ViewChannelAccounts(ctx) + if err != nil { + c.CrashTrackerClient.LogAndReportErrors(ctx, err, "Cmd channel-accounts view crash") + log.Ctx(ctx).Fatalf("Error viewing channel accounts: %s", err.Error()) + } + }, + } + + return viewCmd +} diff --git a/cmd/channel_accounts_test.go b/cmd/channel_accounts_test.go new file mode 100644 index 000000000..c47071343 --- /dev/null +++ b/cmd/channel_accounts_test.go @@ -0,0 +1,392 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "strconv" + "testing" + + "github.com/spf13/cobra" + "github.com/stellar/go/keypair" + "github.com/stellar/go/txnbuild" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + txSubSvc "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/services" +) + +func Test_ChannelAccountsCommand_Command(t *testing.T) { + dbt := dbtest.Open(t) + + caCommand := &ChannelAccountsCommand{} + + root := rootCmd() + cmd := caCommand.Command() + root.AddCommand(cmd) + + root.SetArgs([]string{ + "channel-accounts", + "verify", + "--database-url", + dbt.DSN, + }) + err := cmd.Execute() + require.NoError(t, err) +} + +func Test_ChannelAccountsCommand_CreateCommand(t *testing.T) { + caServiceMock := &txSubSvc.ChannelAccountsServiceMock{} + crashTrackerMock := &crashtracker.MockCrashTrackerClient{} + caCommand := &ChannelAccountsCommand{ + Service: caServiceMock, + CrashTrackerClient: crashTrackerMock, + } + + parentCmdMock := &cobra.Command{ + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + } + + cmd := caCommand.CreateCommand(&txSubSvc.ChannelAccountServiceOptions{}) + parentCmdMock.AddCommand(cmd) + + distributionSeed := keypair.MustRandom().Seed() + encryptKey := true + + parentCmdMock.SetArgs([]string{ + "create", + "--distribution-seed", + distributionSeed, + "--num-channel-accounts-create", + "2", + "--encrypt-key", + strconv.FormatBool(encryptKey), + }) + + t.Run("exit with status 1 when ChannelAccountsService fails", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + customErr := errors.New("unexpected error") + caServiceMock. + On("CreateChannelAccountsOnChain", context.Background(), txSubSvc.ChannelAccountServiceOptions{ + NumChannelAccounts: 2, + MaxBaseFee: txnbuild.MinBaseFee, + RootSeed: distributionSeed, + EncryptKey: encryptKey, + }). + Return(customErr) + crashTrackerMock.On("LogAndReportErrors", context.Background(), customErr, "Cmd channel-accounts create crash") + + err := parentCmdMock.Execute() + require.NoError(t, err) + + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("executes the create command successfully", func(t *testing.T) { + caServiceMock. + On("CreateChannelAccountsOnChain", context.Background(), txSubSvc.ChannelAccountServiceOptions{ + NumChannelAccounts: 2, + MaxBaseFee: txnbuild.MinBaseFee, + RootSeed: distributionSeed, + EncryptKey: encryptKey, + }). + Return(nil) + + err := parentCmdMock.Execute() + require.NoError(t, err) + }) + + caServiceMock.AssertExpectations(t) + crashTrackerMock.AssertExpectations(t) +} + +func Test_ChannelAccountsCommand_VerifyCommand(t *testing.T) { + caServiceMock := &txSubSvc.ChannelAccountsServiceMock{} + crashTrackerMock := &crashtracker.MockCrashTrackerClient{} + caCommand := &ChannelAccountsCommand{Service: caServiceMock} + + parentCmdMock := &cobra.Command{ + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + } + + cmd := caCommand.VerifyCommand(&txSubSvc.ChannelAccountServiceOptions{}) + parentCmdMock.AddCommand(cmd) + + parentCmdMock.SetArgs([]string{ + "verify", + }) + + t.Run("exit with status 1 when ChannelAccountsService fails", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + customErr := errors.New("unexpected error") + caServiceMock. + On("VerifyChannelAccounts", context.Background()). + Return(customErr) + crashTrackerMock.On("LogAndReportErrors", context.Background(), customErr, "Cmd channel-accounts verify crash") + + err := parentCmdMock.Execute() + require.NoError(t, err) + + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("executes the verify command successfully", func(t *testing.T) { + caServiceMock. + On("VerifyChannelAccounts", context.Background()). + Return(nil) + + err := parentCmdMock.Execute() + require.NoError(t, err) + }) + + caServiceMock.AssertExpectations(t) + crashTrackerMock.AssertExpectations(t) +} + +func Test_ChannelAccountsCommand_EnsureCommand(t *testing.T) { + caServiceMock := &txSubSvc.ChannelAccountsServiceMock{} + crashTrackerMock := &crashtracker.MockCrashTrackerClient{} + caCommand := &ChannelAccountsCommand{Service: caServiceMock} + + parentCmdMock := &cobra.Command{ + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + } + + cmd := caCommand.EnsureCommand(&txSubSvc.ChannelAccountServiceOptions{}) + parentCmdMock.AddCommand(cmd) + + distributionSeed := keypair.MustRandom().Seed() + encryptKey := true + + parentCmdMock.SetArgs([]string{ + "ensure", + "--distribution-seed", + distributionSeed, + "--num-channel-accounts-ensure", + "2", + "--encrypt-key", + strconv.FormatBool(encryptKey), + }) + + t.Run("exit with status 1 when ChannelAccountsService fails", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + customErr := errors.New("unexpected error") + caServiceMock. + On("EnsureChannelAccountsCount", context.Background(), txSubSvc.ChannelAccountServiceOptions{ + MaxBaseFee: txnbuild.MinBaseFee, + NumChannelAccounts: 2, + RootSeed: distributionSeed, + EncryptKey: encryptKey, + }). + Return(customErr) + crashTrackerMock.On("LogAndReportErrors", context.Background(), customErr, "Cmd channel-accounts ensure crash") + + err := parentCmdMock.Execute() + require.NoError(t, err) + + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("executs the ensure command successfully", func(t *testing.T) { + caServiceMock. + On("EnsureChannelAccountsCount", context.Background(), txSubSvc.ChannelAccountServiceOptions{ + MaxBaseFee: txnbuild.MinBaseFee, + NumChannelAccounts: 2, + RootSeed: distributionSeed, + EncryptKey: encryptKey, + }). + Return(nil) + + err := parentCmdMock.Execute() + require.NoError(t, err) + }) + + caServiceMock.AssertExpectations(t) + crashTrackerMock.AssertExpectations(t) +} + +func Test_ChannelAccountsCommand_DeleteCommand(t *testing.T) { + caServiceMock := &txSubSvc.ChannelAccountsServiceMock{} + crashTrackerMock := &crashtracker.MockCrashTrackerClient{} + caCommand := &ChannelAccountsCommand{Service: caServiceMock} + + parentCmdMock := &cobra.Command{ + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + } + + cmd := caCommand.DeleteCommand(&txSubSvc.ChannelAccountServiceOptions{}) + parentCmdMock.AddCommand(cmd) + + distributionSeed := keypair.MustRandom().Seed() + + args := []string{ + "delete", + "--distribution-seed", + distributionSeed, + "--channel-account-id", + "acc-id", + } + + t.Run("exit with status 1 when ChannelAccountsService fails", func(t *testing.T) { + parentCmdMock.SetArgs(args) + customErr := errors.New("unexpected error") + if os.Getenv("TEST_FATAL") == "1" { + caServiceMock. + On("DeleteChannelAccount", context.Background(), txSubSvc.ChannelAccountServiceOptions{ + MaxBaseFee: txnbuild.MinBaseFee, + ChannelAccountID: "acc-id", + RootSeed: distributionSeed, + }). + Return(customErr) + crashTrackerMock.On("LogAndReportErrors", context.Background(), customErr, "Cmd channel-accounts delete crash") + + err := parentCmdMock.Execute() + require.NoError(t, err) + + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("executes the delete command successfully", func(t *testing.T) { + parentCmdMock.SetArgs(args) + caServiceMock. + On("DeleteChannelAccount", context.Background(), txSubSvc.ChannelAccountServiceOptions{ + MaxBaseFee: txnbuild.MinBaseFee, + ChannelAccountID: "acc-id", + RootSeed: distributionSeed, + }). + Return(nil) + + err := parentCmdMock.Execute() + require.NoError(t, err) + }) + + t.Run("delete command fails when both channel-account-id and delete-all-accounts are set", func(t *testing.T) { + parentCmdMock.SetArgs(append(args, "--delete-all-accounts")) + + err := parentCmdMock.Execute() + require.EqualError( + t, + err, + "if any flags in the group [channel-account-id delete-all-accounts] are set none of the others can be; [channel-account-id delete-all-accounts] were all set", + ) + }) + + caServiceMock.AssertExpectations(t) + crashTrackerMock.AssertExpectations(t) +} + +func Test_ChannelAccountsCommand_ViewCommand(t *testing.T) { + caServiceMock := &txSubSvc.ChannelAccountsServiceMock{} + crashTrackerMock := &crashtracker.MockCrashTrackerClient{} + caCommand := &ChannelAccountsCommand{Service: caServiceMock} + + parentCmdMock := &cobra.Command{ + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + } + + cmd := caCommand.ViewCommand() + parentCmdMock.AddCommand(cmd) + + parentCmdMock.SetArgs([]string{ + "view", + }) + + t.Run("exit with status 1 when ChannelAccountsService fails", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + customErr := errors.New("unexpected error") + caServiceMock. + On("ViewChannelAccounts", context.Background()). + Return(errors.New("unexpected error")) + crashTrackerMock.On("LogAndReportErrors", context.Background(), customErr, "Cmd channel-accounts view crash") + + err := parentCmdMock.Execute() + require.NoError(t, err) + + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("executes the view command successfully", func(t *testing.T) { + caServiceMock. + On("ViewChannelAccounts", context.Background()). + Return(nil) + + err := parentCmdMock.Execute() + require.NoError(t, err) + }) + + caServiceMock.AssertExpectations(t) + crashTrackerMock.AssertExpectations(t) +} diff --git a/cmd/db.go b/cmd/db.go new file mode 100644 index 000000000..0ac046b5d --- /dev/null +++ b/cmd/db.go @@ -0,0 +1,141 @@ +package cmd + +import ( + "fmt" + "strconv" + + migrate "github.com/rubenv/sql-migrate" + "github.com/spf13/cobra" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/cli" +) + +type DatabaseCommand struct{} + +func (c *DatabaseCommand) Command() *cobra.Command { + cmd := &cobra.Command{ + Use: "db", + Short: "Database related commands", + Run: func(cmd *cobra.Command, _ []string) { + err := cmd.Help() + if err != nil { + log.Fatalf("Error calling help command: %s", err.Error()) + } + }, + } + + migrateCmd := &cobra.Command{ + Use: "migrate", + Short: "Schema migration helpers", + Run: func(cmd *cobra.Command, _ []string) { + err := cmd.Help() + if err != nil { + log.Fatalf("Error calling help command: %s", err.Error()) + } + }, + } + cmd.AddCommand(migrateCmd) + + migrateUp := &cobra.Command{ + Use: "up", + Short: "Migrates database up [count]", + Args: cobra.MaximumNArgs(1), + Run: func(cmd *cobra.Command, args []string) { + var count int + if len(args) > 0 { + var err error + count, err = strconv.Atoi(args[0]) + if err != nil { + log.Fatalf("Invalid [count] argument: %s", args[0]) + } + } + + err := c.migrate(migrate.Up, count) + if err != nil { + log.Fatalf("Error migrating database Up: %s", err.Error()) + } + }, + } + migrateCmd.AddCommand(migrateUp) + + migrateDown := &cobra.Command{ + Use: "down [count]", + Short: "Migrates database down [count] migrations", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + count, err := strconv.Atoi(args[0]) + if err != nil { + log.Fatalf("Invalid [count] argument: %s", args[0]) + } + + err = c.migrate(migrate.Down, count) + if err != nil { + log.Fatalf("Error migrating database Down: %s", err.Error()) + } + }, + } + migrateCmd.AddCommand(migrateDown) + + setupForNetwork := &cobra.Command{ + Use: "setup-for-network", + Short: "Set up the assets and wallets registered in the database based on the network passphrase.", + Long: "Set up the assets and wallets registered in the database based on the network passphrase. It inserts or updates the entries of these tables according with the configured Network Passphrase.", + Run: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + + dbConnectionPool, err := db.OpenDBConnectionPool(globalOptions.databaseURL) + if err != nil { + log.Ctx(ctx).Fatalf("error connection to the database: %s", err.Error()) + } + defer dbConnectionPool.Close() + + networkType, err := utils.GetNetworkTypeFromNetworkPassphrase(globalOptions.networkPassphrase) + if err != nil { + log.Ctx(ctx).Fatalf("error getting network type: %s", err.Error()) + } + + if err := services.SetupWalletsForProperNetwork(ctx, dbConnectionPool, networkType, services.DefaultWalletsNetworkMap); err != nil { + log.Ctx(ctx).Fatalf("error upserting wallets for proper network: %s", err.Error()) + } + + if err := services.SetupAssetsForProperNetwork(ctx, dbConnectionPool, networkType, services.DefaultAssetsNetworkMap); err != nil { + log.Ctx(ctx).Fatalf("error upserting assets for proper network: %s", err.Error()) + } + }, + } + cmd.AddCommand(setupForNetwork) + + stellarAuthMigrateCmd := &cobra.Command{ + Use: "auth", + Short: "Stellar Auth schema migration helpers", + Example: "stellarauth migrate [direction] [count]", + Run: func(cmd *cobra.Command, args []string) { + if err := cmd.Help(); err != nil { + log.Fatalf("Error calling help command: %s", err.Error()) + } + }, + } + stellarAuthMigrateCmd.AddCommand(cli.MigrateCmd(dbConfigOptionFlagName)) + + // Add `auth` as a sub-command to `db`. Usage: db auth migrate up + cmd.AddCommand(stellarAuthMigrateCmd) + + return cmd +} + +func (c *DatabaseCommand) migrate(dir migrate.MigrationDirection, count int) error { + numMigrationsRun, err := db.Migrate(globalOptions.databaseURL, dir, count) + if err != nil { + return fmt.Errorf("migrating database: %w", err) + } + + if numMigrationsRun == 0 { + log.Info("No migrations applied.") + } else { + log.Infof("Successfully applied %d migrations.", numMigrationsRun) + } + return nil +} diff --git a/cmd/db_test.go b/cmd/db_test.go new file mode 100644 index 000000000..daf9af3a2 --- /dev/null +++ b/cmd/db_test.go @@ -0,0 +1,236 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getMigrationsApplied(t *testing.T, ctx context.Context, db db.DBConnectionPool) []string { + rows, err := db.QueryContext(ctx, "SELECT id FROM gorp_migrations") + require.NoError(t, err) + defer rows.Close() + + ids := []string{} + for rows.Next() { + var id string + err := rows.Scan(&id) + require.NoError(t, err) + + ids = append(ids, id) + } + + require.NoError(t, rows.Err()) + + return ids +} + +func Test_DatabaseCommand_db_help(t *testing.T) { + buf := new(strings.Builder) + + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + rootCmd.SetArgs([]string{"db"}) + rootCmd.SetOut(buf) + err := rootCmd.Execute() + require.NoError(t, err) + + expectedContains := []string{ + "Database related commands", + "stellar-disbursement-platform db [flags]", + "stellar-disbursement-platform db [command]", + "auth Stellar Auth schema migration helpers", + "migrate Schema migration helpers", + "setup-for-network Set up the assets and wallets registered in the database based on the network passphrase.", + "-h, --help help for db", + `--base-url string The SDP UI base URL. (BASE_URL) (default "http://localhost:8000")`, + `--database-url string Postgres DB URL (DATABASE_URL) (default "postgres://localhost:5432/sdp?sslmode=disable")`, + `--environment string The environment where the application is running. Example: "development", "staging", "production". (ENVIRONMENT) (default "development")`, + `--log-level string The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC". (LOG_LEVEL) (default "TRACE")`, + `--network-passphrase string The Stellar network passphrase (NETWORK_PASSPHRASE) (default "Test SDF Network ; September 2015")`, + `--sentry-dsn string The DSN (client key) of the Sentry project. If not provided, Sentry will not be used. (SENTRY_DSN)`, + } + + output := buf.String() + for _, expected := range expectedContains { + assert.Contains(t, output, expected) + } + + buf.Reset() + rootCmd.SetArgs([]string{"db", "--help"}) + err = rootCmd.Execute() + require.NoError(t, err) + + output = buf.String() + for _, expected := range expectedContains { + assert.Contains(t, output, expected) + } +} + +func Test_DatabaseCommand_db_migrate(t *testing.T) { + dbt := dbtest.OpenWithoutMigrations(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + buf := new(strings.Builder) + + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + rootCmd.SetArgs([]string{"db", "migrate"}) + rootCmd.SetOut(buf) + err = rootCmd.Execute() + require.NoError(t, err) + + expectedContains := []string{ + "Schema migration helpers", + "stellar-disbursement-platform db migrate [flags]", + "stellar-disbursement-platform db migrate [command]", + "down Migrates database down [count] migrations", + "up Migrates database up [count]", + "-h, --help help for migrate", + `--base-url string The SDP UI base URL. (BASE_URL) (default "http://localhost:8000")`, + `--database-url string Postgres DB URL (DATABASE_URL) (default "postgres://localhost:5432/sdp?sslmode=disable")`, + `--environment string The environment where the application is running. Example: "development", "staging", "production". (ENVIRONMENT) (default "development")`, + `--log-level string The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC". (LOG_LEVEL) (default "TRACE")`, + `--network-passphrase string The Stellar network passphrase (NETWORK_PASSPHRASE) (default "Test SDF Network ; September 2015")`, + `--sentry-dsn string The DSN (client key) of the Sentry project. If not provided, Sentry will not be used. (SENTRY_DSN)`, + } + + output := buf.String() + for _, expected := range expectedContains { + assert.Contains(t, output, expected) + } + + buf.Reset() + log.DefaultLogger.SetOutput(buf) + rootCmd = SetupCLI("x.y.z", "1234567890abcdef") + rootCmd.SetArgs([]string{"db", "migrate", "up", "1", "--database-url", dbt.DSN, "--log-level", "TRACE"}) + err = rootCmd.Execute() + require.NoError(t, err) + + ids := getMigrationsApplied(t, context.Background(), dbConnectionPool) + assert.Equal(t, []string{"2023-01-20.0-initial.sql"}, ids) + + assert.Contains(t, buf.String(), "Successfully applied 1 migrations.") + + buf.Reset() + rootCmd = SetupCLI("x.y.z", "1234567890abcdef") + rootCmd.SetArgs([]string{"db", "migrate", "down", "1", "--database-url", dbt.DSN, "--log-level", "TRACE"}) + err = rootCmd.Execute() + require.NoError(t, err) + + ids = getMigrationsApplied(t, context.Background(), dbConnectionPool) + assert.Equal(t, []string{}, ids) + + assert.Contains(t, buf.String(), "Successfully applied 1 migrations.") +} + +func Test_DatabaseCommand_db_setup_for_network(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + + // Assets + testnetUSDCIssuer := keypair.MustRandom().Address() + data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", testnetUSDCIssuer) + + assets, err := models.Assets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, assets, 1) + assert.Equal(t, "USDC", assets[0].Code) + assert.Equal(t, testnetUSDCIssuer, assets[0].Issuer) + + // Wallets + data.CreateWalletFixture(t, ctx, dbConnectionPool, "Vibrant Assist", "https://vibrantapp.com", "api-dev.vibrantapp.com", "https://vibrantapp.com/sdp-dev") + + wallets, err := models.Wallets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, wallets, 1) + assert.Equal(t, "Vibrant Assist", wallets[0].Name) + assert.Equal(t, "https://vibrantapp.com", wallets[0].Homepage) + assert.Equal(t, "api-dev.vibrantapp.com", wallets[0].SEP10ClientDomain) + assert.Equal(t, "https://vibrantapp.com/sdp-dev", wallets[0].DeepLinkSchema) + + buf := new(strings.Builder) + log.DefaultLogger.SetLevel(log.InfoLevel) + log.DefaultLogger.SetOutput(buf) + + // Setup + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + rootCmd.SetArgs([]string{ + "db", + "setup-for-network", + "--database-url", + dbt.DSN, + "--network-passphrase", + network.PublicNetworkPassphrase, + }) + + err = rootCmd.Execute() + require.NoError(t, err) + + // Validating assets + assets, err = models.Assets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, assets, 1) + assert.Equal(t, "USDC", assets[0].Code) + assert.NotEqual(t, testnetUSDCIssuer, assets[0].Issuer) + assert.Equal(t, services.DefaultAssetsNetworkMap[utils.PubnetNetworkType]["USDC"], assets[0].Issuer) + + // Validating wallets + wallets, err = models.Wallets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, wallets, 1) + // assert.Equal(t, "Beans App", wallets[0].Name) + // assert.Equal(t, "https://www.beansapp.com/disbursements", wallets[0].Homepage) + // assert.Equal(t, "api.beansapp.com", wallets[0].SEP10ClientDomain) + // assert.Equal(t, "https://www.beansapp.com/disbursements/registration?redirect=true", wallets[0].DeepLinkSchema) + assert.Equal(t, "Vibrant Assist", wallets[0].Name) + assert.Equal(t, "https://vibrantapp.com/assist", wallets[0].Homepage) + assert.Equal(t, "api.vibrantapp.com", wallets[0].SEP10ClientDomain) + assert.Equal(t, "https://vibrantapp.com/sdp", wallets[0].DeepLinkSchema) + + expectedLogs := []string{ + "updating/inserting assets for the 'pubnet' network", + "Code: USDC", + fmt.Sprintf("Issuer: %s", services.DefaultAssetsNetworkMap[utils.PubnetNetworkType]["USDC"]), + "updating/inserting wallets for the 'pubnet' network", + "Name: Vibrant Assist", + "Homepage: https://vibrantapp.com/assist", + "Deep Link Schema: https://vibrantapp.com/sdp", + "SEP-10 Client Domain: api.vibrantapp.com", + } + + logs := buf.String() + for _, expectedLog := range expectedLogs { + assert.Contains(t, logs, expectedLog) + } +} diff --git a/cmd/integration_tests.go b/cmd/integration_tests.go new file mode 100644 index 000000000..4a71fbda9 --- /dev/null +++ b/cmd/integration_tests.go @@ -0,0 +1,254 @@ +package cmd + +import ( + "go/types" + + "github.com/spf13/cobra" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/integrationtests" +) + +type IntegrationTestsCommand struct { + Service integrationtests.IntegrationTestsInterface +} + +func (c *IntegrationTestsCommand) Command() *cobra.Command { + integrationTestsOpts := &integrationtests.IntegrationTestsOpts{} + + configOpts := config.ConfigOptions{ + { + Name: "disbursed-asset-code", + Usage: "Code of the asset to be disbursed", + OptType: types.String, + ConfigKey: &integrationTestsOpts.DisbursedAssetCode, + Required: true, + }, + { + Name: "disbursed-asset-issuer", + Usage: "Issuer if the asset to be disbursed", + OptType: types.String, + ConfigKey: &integrationTestsOpts.DisbursetAssetIssuer, + Required: true, + }, + { + Name: "disbursement-name", + Usage: "Disbursement name to be used in integration tests", + OptType: types.String, + ConfigKey: &integrationTestsOpts.DisbursementName, + FlagDefault: "disbursement_integration_tests", + Required: true, + }, + { + Name: "wallet-name", + Usage: "Wallet name to be used in integration tests", + OptType: types.String, + ConfigKey: &integrationTestsOpts.WalletName, + FlagDefault: "Integration test wallet", + Required: true, + }, + } + integrationTestsCmd := &cobra.Command{ + Use: "integration-tests", + Short: "Integration tests related commands", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Fatalf("Error setting values of config options: %s", err.Error()) + } + + // inject database url to integration tests opts + integrationTestsOpts.DatabaseDSN = globalOptions.databaseURL + + c.Service, err = integrationtests.NewIntegrationTestsService(*integrationTestsOpts) + if err != nil { + log.Fatalf("error creating integration tests service: %s", err.Error()) + } + }, + } + err := configOpts.Init(integrationTestsCmd) + if err != nil { + log.Fatalf("Error initializing a config option: %s", err.Error()) + } + + startIntegrationTestsCmd := c.StartIntegrationTestsCommand(integrationTestsOpts) + createIntegrationTestsDataCmd := c.CreateIntegrationTestsDataCommand(integrationTestsOpts) + integrationTestsCmd.AddCommand(startIntegrationTestsCmd, createIntegrationTestsDataCmd) + + return integrationTestsCmd +} + +func (c *IntegrationTestsCommand) StartIntegrationTestsCommand(integrationTestsOpts *integrationtests.IntegrationTestsOpts) *cobra.Command { + configOpts := config.ConfigOptions{ + { + Name: "user-email", + Usage: "Email from SDP authenticated user with all roles", + OptType: types.String, + ConfigKey: &integrationTestsOpts.UserEmail, + Required: true, + }, + { + Name: "user-password", + Usage: "Password from SDP authenticated user with all roles", + OptType: types.String, + ConfigKey: &integrationTestsOpts.UserPassword, + Required: true, + }, + { + Name: "receiver-account-public-key", + Usage: "Integration test receiver public stellar account key", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPublicKey, + ConfigKey: &integrationTestsOpts.ReceiverAccountPublicKey, + Required: true, + }, + { + Name: "receiver-account-private-key", + Usage: "Integration test receiver private stellar account key", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPrivateKey, + ConfigKey: &integrationTestsOpts.ReceiverAccountPrivateKey, + Required: true, + }, + { + Name: "receiver-account-stellar-memo", + Usage: "Integration test receiver stellar memo", + OptType: types.String, + ConfigKey: &integrationTestsOpts.ReceiverAccountStellarMemo, + Required: false, + }, + { + Name: "sep10-signing-public-key", + Usage: "Anchor platform SEP10 signing public key", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPublicKey, + ConfigKey: &integrationTestsOpts.Sep10SigningPublicKey, + Required: true, + }, + { + Name: "disbursement-csv-file-name", + Usage: "File name of the integration test disbursement file.", + OptType: types.String, + ConfigKey: &integrationTestsOpts.DisbursementCSVFileName, + Required: true, + }, + { + Name: "disbursement-csv-file-path", + Usage: "File path of the integration test disbursement file.", + OptType: types.String, + ConfigKey: &integrationTestsOpts.DisbursementCSVFilePath, + Required: true, + }, + { + Name: "server-api-base-url", + Usage: "The Base URL of the server API of the SDP.", + OptType: types.String, + ConfigKey: &integrationTestsOpts.ServerApiBaseURL, + Required: true, + }, + { + Name: "anchor-platform-base-sep-url", + Usage: "The Base URL of the sep server of the anchor platform. This is the base URL where the Anchor Platform " + + "exposes its public API that is meant to be reached by a client application, such as the stellar.toml file.", + OptType: types.String, + ConfigKey: &integrationTestsOpts.AnchorPlatformBaseSepURL, + Required: true, + }, + { + Name: "recaptcha-site-key", + Usage: "The Google reCAPTCHA v2 - I'm not a robot site key.", + OptType: types.String, + ConfigKey: &integrationTestsOpts.RecaptchaSiteKey, + FlagDefault: "6LeIxAcTAAAAAJcZVRqyHh71UMIEGNQ_MXjiZKhI", + Required: true, + }, + } + + startIntegrationTestsCmd := &cobra.Command{ + Use: "start", + Short: "Run the e2e tests of the sdp application", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + + err := c.Service.StartIntegrationTests(ctx, *integrationTestsOpts) + if err != nil { + log.Fatalf("Error starting integration tests: %s", err.Error()) + } + }, + } + + err := configOpts.Init(startIntegrationTestsCmd) + if err != nil { + log.Fatalf("Error initializing startIntegrationTestsCmd: %s", err.Error()) + } + + return startIntegrationTestsCmd +} + +func (c *IntegrationTestsCommand) CreateIntegrationTestsDataCommand(integrationTestsOpts *integrationtests.IntegrationTestsOpts) *cobra.Command { + configOpts := config.ConfigOptions{ + { + Name: "wallet-homepage", + Usage: "Wallet homepage to be used in integration tests", + OptType: types.String, + ConfigKey: &integrationTestsOpts.WalletHomepage, + FlagDefault: "https://www.test_wallet.com", + Required: true, + }, + { + Name: "wallet-deeplink", + Usage: "Wallet deeplink to be used in integration tests", + OptType: types.String, + ConfigKey: &integrationTestsOpts.WalletDeepLink, + FlagDefault: "test_wallet://", + Required: true, + }, + } + + createIntegrationTestsDataCmd := &cobra.Command{ + Use: "create-data", + Short: "Create integration tests data.", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + + err := c.Service.CreateTestData(ctx, *integrationTestsOpts) + if err != nil { + log.Fatalf("Error creating integration tests data: %s", err.Error()) + } + }, + } + + err := configOpts.Init(createIntegrationTestsDataCmd) + if err != nil { + log.Fatalf("Error initializing createIntegrationTestsDataCmd: %s", err.Error()) + } + + return createIntegrationTestsDataCmd +} diff --git a/cmd/integration_tests_test.go b/cmd/integration_tests_test.go new file mode 100644 index 000000000..2e6cc965a --- /dev/null +++ b/cmd/integration_tests_test.go @@ -0,0 +1,213 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "testing" + + "github.com/spf13/cobra" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/integrationtests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockIntegrationTests struct { + mock.Mock +} + +// Making sure that mockServer implements ServerServiceInterface +var _ integrationtests.IntegrationTestsInterface = (*mockIntegrationTests)(nil) + +func (m *mockIntegrationTests) StartIntegrationTests(ctx context.Context, opts integrationtests.IntegrationTestsOpts) error { + return m.Called(ctx, opts).Error(0) +} + +func (m *mockIntegrationTests) CreateTestData(ctx context.Context, opts integrationtests.IntegrationTestsOpts) error { + return m.Called(ctx, opts).Error(0) +} + +func Test_IntegrationTestsCommand_Command(t *testing.T) { + dbt := dbtest.Open(t) + + command := &IntegrationTestsCommand{} + + root := rootCmd() + cmd := command.Command() + root.AddCommand(cmd) + + t.Setenv("DISBURSED_ASSET_CODE", "USDC") + t.Setenv("DISBURSED_ASSET_ISSUER", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + t.Setenv("WALLET_NAME", "walletTest") + t.Setenv("WALLET_HOMEPAGE", "https://www.test_wallet.com") + t.Setenv("WALLET_DEEPLINK", "test_wallet://") + + root.SetArgs([]string{ + "integration-tests", + "create-data", + "--database-url", + dbt.DSN, + }) + err := cmd.Execute() + require.NoError(t, err) +} + +func Test_IntegrationTestsCommand_StartIntegrationTestsCommand(t *testing.T) { + serviceMock := &mockIntegrationTests{} + command := &IntegrationTestsCommand{Service: serviceMock} + + parentCmdMock := &cobra.Command{ + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + } + + integrationTestsOpts := &integrationtests.IntegrationTestsOpts{ + DatabaseDSN: "randomDatabaseDSN", + UserEmail: "mockemail@test.com", + UserPassword: "mockPassword123!", + DisbursedAssetCode: "USDC", + DisbursetAssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + WalletName: "walletTest", + DisbursementCSVFilePath: "mockPath", + DisbursementCSVFileName: "file.csv", + ReceiverAccountPublicKey: "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", + ReceiverAccountPrivateKey: "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5", + ReceiverAccountStellarMemo: "memo", + Sep10SigningPublicKey: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + RecaptchaSiteKey: "reCAPTCHASiteKey", + AnchorPlatformBaseSepURL: "localhost:8080", + ServerApiBaseURL: "localhost:8000", + } + + cmd := command.StartIntegrationTestsCommand(integrationTestsOpts) + parentCmdMock.AddCommand(cmd) + + t.Setenv("DATABASE_URL", "randomDatabaseDSN") + t.Setenv("USER_EMAIL", "mockemail@test.com") + t.Setenv("USER_PASSWORD", "mockPassword123!") + t.Setenv("DISBURSED_ASSET_CODE", "USDC") + t.Setenv("DISBURSED_ASSET_ISSUER", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + t.Setenv("WALLET_NAME", "walletTest") + t.Setenv("DISBURSEMENT_CSV_FILE_PATH", "mockPath") + t.Setenv("DISBURSEMENT_CSV_FILE_NAME", "file.csv") + t.Setenv("RECEIVER_ACCOUNT_PUBLIC_KEY", "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA") + t.Setenv("RECEIVER_ACCOUNT_PRIVATE_KEY", "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5") + t.Setenv("RECEIVER_ACCOUNT_STELLAR_MEMO", "memo") + t.Setenv("SEP10_SIGNING_PUBLIC_KEY", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S") + t.Setenv("RECAPTCHA_SITE_KEY", "reCAPTCHASiteKey") + t.Setenv("ANCHOR_PLATFORM_BASE_SEP_URL", "localhost:8080") + t.Setenv("SERVER_API_BASE_URL", "localhost:8000") + + parentCmdMock.SetArgs([]string{ + "start", + }) + + t.Run("exit with status 1 when IntegrationTestsService fails", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + serviceMock. + On("StartIntegrationServe", context.Background(), *integrationTestsOpts). + Return(errors.New("unexpected error")) + + err := parentCmdMock.Execute() + require.NoError(t, err) + + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("executes the start integration tests command successfully", func(t *testing.T) { + serviceMock. + On("StartIntegrationTests", context.Background(), *integrationTestsOpts). + Return(nil) + + err := parentCmdMock.Execute() + require.NoError(t, err) + }) + + serviceMock.AssertExpectations(t) +} + +func Test_IntegrationTestsCommand_CreateIntegrationTestsDataCommand(t *testing.T) { + serviceMock := &mockIntegrationTests{} + command := &IntegrationTestsCommand{Service: serviceMock} + + parentCmdMock := &cobra.Command{ + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + } + + integrationTestsOpts := &integrationtests.IntegrationTestsOpts{ + DatabaseDSN: "randomDatabaseDSN", + DisbursedAssetCode: "USDC", + DisbursetAssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + WalletName: "walletTest", + WalletHomepage: "https://www.test_wallet.com", + WalletDeepLink: "test_wallet://", + } + + cmd := command.CreateIntegrationTestsDataCommand(integrationTestsOpts) + parentCmdMock.AddCommand(cmd) + + t.Setenv("DATABASE_URL", "randomDatabaseDSN") + t.Setenv("DISBURSED_ASSET_CODE", "USDC") + t.Setenv("DISBURSED_ASSET_ISSUER", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + t.Setenv("WALLET_NAME", "walletTest") + t.Setenv("WALLET_HOMEPAGE", "https://www.test_wallet.com") + t.Setenv("WALLET_DEEPLINK", "test_wallet://") + + parentCmdMock.SetArgs([]string{ + "create-data", + }) + + t.Run("exit with status 1 when IntegrationTestsService fails", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + serviceMock. + On("CreateTestData", context.Background(), *integrationTestsOpts). + Return(errors.New("unexpected error")) + + err := parentCmdMock.Execute() + require.NoError(t, err) + + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("executes the create integration tests data command successfully", func(t *testing.T) { + serviceMock. + On("CreateTestData", context.Background(), *integrationTestsOpts). + Return(nil) + + err := parentCmdMock.Execute() + require.NoError(t, err) + }) + + serviceMock.AssertExpectations(t) +} diff --git a/cmd/message.go b/cmd/message.go new file mode 100644 index 000000000..e0a18e887 --- /dev/null +++ b/cmd/message.go @@ -0,0 +1,146 @@ +package cmd + +import ( + "fmt" + "go/types" + + "github.com/spf13/cobra" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" +) + +type MessageCommand struct{} + +type MessengerServiceInterface interface { + GetClient(opts message.MessengerOptions) (message.MessengerClient, error) + SendMessage(opts message.MessengerOptions, message message.Message) error +} + +type MessengerService struct{} + +func (m *MessengerService) GetClient(opts message.MessengerOptions) (message.MessengerClient, error) { + return message.GetClient(opts) +} + +func (m *MessengerService) SendMessage(opts message.MessengerOptions, message message.Message) error { + messengerClient, err := m.GetClient(opts) + if err != nil { + return fmt.Errorf("getting messenger client: %w", err) + } + + return messengerClient.SendMessage(message) +} + +func (s *MessageCommand) Command(messengerService MessengerServiceInterface) *cobra.Command { + opts := message.MessengerOptions{} + messageCmdConfigOpts := config.ConfigOptions{ + // message sender type + { + Name: "message-sender-type", + Usage: `Message Sender Type. Options: "TWILIO_SMS", "AWS_SMS", "AWS_EMAIL", "DRY_RUN"`, + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionMessengerType, + ConfigKey: &opts.MessengerType, + Required: true, + }, + } + messageCmdConfigOpts = append(messageCmdConfigOpts, cmdUtils.TwilioConfigOptions(&opts)...) + messageCmdConfigOpts = append(messageCmdConfigOpts, cmdUtils.AWSConfigOptions(&opts)...) + + messageCmd := &cobra.Command{ + Use: "message", + Short: "Messenger related commands", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + // Inject dependencies: + opts.Environment = globalOptions.environment + + // Validate & ingest input parameters + messageCmdConfigOpts.Require() + err := messageCmdConfigOpts.SetValues() + if err != nil { + log.Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, _ []string) { + _, err := messengerService.GetClient(opts) + if err != nil { + log.Fatalf("Error calling help command: %s", err.Error()) + } + + log.Infof("πŸŽ‰ Successfully mounted messenger client for type %s", opts.MessengerType) + }, + } + err := messageCmdConfigOpts.Init(messageCmd) + if err != nil { + log.Fatalf("Error initializing messageCmd config option: %s", err.Error()) + } + + sendMessageCmd := s.sendMessageCommand(messengerService, &opts) + messageCmd.AddCommand(sendMessageCmd) + + return messageCmd +} + +func (s *MessageCommand) sendMessageCommand(messengerService MessengerServiceInterface, messageOptions *message.MessengerOptions) *cobra.Command { + msg := message.Message{} + // CLI arguments to send a message + sendMessageCmdConfigOpts := config.ConfigOptions{ + { + Name: "phone-number", + Usage: "The phone number to send the message to, in E.164. Mandatory if sending an SMS", + OptType: types.String, + ConfigKey: &msg.ToPhoneNumber, + Required: false, + }, + { + Name: "email", + Usage: "The email to send the message to. Mandatory if sending an email.", + OptType: types.String, + ConfigKey: &msg.ToEmail, + Required: false, + }, + { + Name: "title", + Usage: "The title to be set in the email. Mandatory if sending an email.", + OptType: types.String, + ConfigKey: &msg.Title, + Required: false, + }, + { + Name: "message", + Usage: "The text of the message to be sent", + OptType: types.String, + ConfigKey: &msg.Message, + Required: true, + }, + } + sendMessageCmd := &cobra.Command{ + Use: "send", + Short: "Send a message", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + sendMessageCmdConfigOpts.Require() + err := sendMessageCmdConfigOpts.SetValues() + if err != nil { + log.Fatalf("Error setting values of config options: %s", err.Error()) + } + }, + Run: func(_ *cobra.Command, _ []string) { + err := messengerService.SendMessage(*messageOptions, msg) + if err != nil { + log.Fatalf("Error sending message: %s", err.Error()) + } + }, + } + err := sendMessageCmdConfigOpts.Init(sendMessageCmd) + if err != nil { + log.Fatalf("Error initializing a sendMessageCmd option: %s", err.Error()) + } + + return sendMessageCmd +} diff --git a/cmd/message_test.go b/cmd/message_test.go new file mode 100644 index 000000000..26be82abe --- /dev/null +++ b/cmd/message_test.go @@ -0,0 +1,119 @@ +package cmd + +import ( + "bytes" + "testing" + + "github.com/spf13/cobra" + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockMessengerService struct { + mock.Mock +} + +func (m *mockMessengerService) GetClient(opts message.MessengerOptions) (message.MessengerClient, error) { + args := m.Called(opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(message.MessengerClient), args.Error(1) +} + +func (m *mockMessengerService) SendMessage(opts message.MessengerOptions, message message.Message) error { + return m.Called(opts, message).Error(0) +} + +func Test_message_help(t *testing.T) { + // setup + var out bytes.Buffer + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + messageCmdFound := false + for _, cmd := range rootCmd.Commands() { + if cmd.Use == "message" { + messageCmdFound = true + } + } + require.True(t, messageCmdFound, "message command not found") + rootCmd.SetArgs([]string{"message", "--help"}) + rootCmd.SetOut(&out) + + // test + err := rootCmd.Execute() + require.NoError(t, err) + + // assert + assert.Contains(t, out.String(), "stellar-disbursement-platform message [flags]", "should have printed help message for message command") +} + +func Test_message_GetClient_wasCalled(t *testing.T) { + cmdUtils.ClearTestEnvironment(t) + + mMessageService := mockMessengerService{} + wantMessageOptions := message.MessengerOptions{ + MessengerType: message.MessengerTypeTwilioSMS, + Environment: "development", + } + mMessageService.On("GetClient", wantMessageOptions).Return(nil, nil).Once() + + // setup + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + var commandToRemove *cobra.Command + commandToAdd := (&MessageCommand{}).Command(&mMessageService) + for _, cmd := range rootCmd.Commands() { + if cmd.Use == "message" { + commandToRemove = cmd + } + } + require.NotNil(t, commandToRemove, "message command not found") + rootCmd.RemoveCommand(commandToRemove) + rootCmd.AddCommand(commandToAdd) + rootCmd.SetArgs([]string{"message", "--message-sender-type", "twilio_sms"}) + + // test + err := rootCmd.Execute() + require.NoError(t, err) + + // assert + mMessageService.AssertExpectations(t) +} + +func Test_message_send_SendMessage_wasCalled(t *testing.T) { + cmdUtils.ClearTestEnvironment(t) + + mMessageService := mockMessengerService{} + wantMessageOptions := message.MessengerOptions{ + MessengerType: message.MessengerTypeTwilioSMS, + Environment: "development", + } + wantMessage := message.Message{ + ToPhoneNumber: "+41555511111", + Message: "hello world", + } + mMessageService.On("SendMessage", wantMessageOptions, wantMessage).Return(nil).Once() + + // setup + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + var commandToRemove *cobra.Command + commandToAdd := (&MessageCommand{}).Command(&mMessageService) + for _, cmd := range rootCmd.Commands() { + if cmd.Use == "message" { + commandToRemove = cmd + } + } + require.NotNil(t, commandToRemove, "message command not found") + rootCmd.RemoveCommand(commandToRemove) + rootCmd.AddCommand(commandToAdd) + rootCmd.SetArgs([]string{"message", "send", "--message-sender-type", "twilio_SMS", "--phone-number", "+41555511111", "--message", "hello world"}) + + // test + err := rootCmd.Execute() + require.NoError(t, err) + + // assert + mMessageService.AssertExpectations(t) +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 000000000..a5adc1800 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,141 @@ +package cmd + +import ( + "go/types" + + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/stellar/go/network" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" +) + +type globalOptionsType struct { + logLevel logrus.Level + sentryDSN string + environment string + version string + gitCommit string + databaseURL string + baseURL string + networkPassphrase string +} + +// populateConfigOptions populates the CrastTrackerOptions from the global options. +func (g globalOptionsType) populateCrashTrackerOptions(crashTrackerOptions *crashtracker.CrashTrackerOptions) { + if crashTrackerOptions.CrashTrackerType == crashtracker.CrashTrackerTypeSentry { + crashTrackerOptions.SentryDSN = g.sentryDSN + } + crashTrackerOptions.Environment = g.environment + crashTrackerOptions.GitCommit = g.gitCommit +} + +// globalOptions is a variable that holds the global CLI options that can be +// applied to any command or subcommand. +var globalOptions globalOptionsType + +const dbConfigOptionFlagName = "database-url" + +func rootCmd() *cobra.Command { + configOpts := config.ConfigOptions{ + { + Name: "log-level", + Usage: `The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC".`, + OptType: types.String, + FlagDefault: "TRACE", + ConfigKey: &globalOptions.logLevel, + CustomSetValue: cmdUtils.SetConfigOptionLogLevel, + Required: true, + }, + { + Name: "sentry-dsn", + Usage: "The DSN (client key) of the Sentry project. If not provided, Sentry will not be used.", + OptType: types.String, + ConfigKey: &globalOptions.sentryDSN, + Required: false, + }, + { + Name: "environment", + Usage: `The environment where the application is running. Example: "development", "staging", "production".`, + OptType: types.String, + FlagDefault: "development", + ConfigKey: &globalOptions.environment, + Required: true, + }, + { + Name: dbConfigOptionFlagName, + Usage: `Postgres DB URL`, + OptType: types.String, + FlagDefault: "postgres://localhost:5432/sdp?sslmode=disable", + ConfigKey: &globalOptions.databaseURL, + Required: true, + }, + { + Name: "base-url", + Usage: "The SDP UI base URL.", + OptType: types.String, + ConfigKey: &globalOptions.baseURL, + FlagDefault: "http://localhost:8000", + Required: true, + }, + { + Name: "network-passphrase", + Usage: "The Stellar network passphrase", + OptType: types.String, + ConfigKey: &globalOptions.networkPassphrase, + FlagDefault: network.TestNetworkPassphrase, + Required: true, + }, + } + + rootCmd := &cobra.Command{ + Use: "stellar-disbursement-platform", + Short: "Stellar Disbursement Platform", + Long: "The Stellar Disbursement Platform (SDP) enables organizations to disburse bulk payments to recipients using Stellar.", + Version: globalOptions.version, + PersistentPreRun: func(cmd *cobra.Command, _ []string) { + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Fatalf("Error setting values of config options: %s", err.Error()) + } + log.Info("Version: ", globalOptions.version) + log.Info("GitCommit: ", globalOptions.gitCommit) + }, + Run: func(cmd *cobra.Command, args []string) { + err := cmd.Help() + if err != nil { + log.Fatalf("Error calling help command: %s", err.Error()) + } + }, + } + + err := configOpts.Init(rootCmd) + if err != nil { + log.Fatalf("Error initializing a config option: %s", err.Error()) + } + + return rootCmd +} + +// SetupCLI sets up the CLI and returns the root command with the subcommands +// attached. +func SetupCLI(version, gitCommit string) *cobra.Command { + globalOptions.version = version + globalOptions.gitCommit = gitCommit + rootCmd := rootCmd() + + // Add subcommands + rootCmd.AddCommand((&ServeCommand{}).Command(&ServerService{}, &monitor.MonitorService{})) + rootCmd.AddCommand((&DatabaseCommand{}).Command()) + rootCmd.AddCommand((&MessageCommand{}).Command(&MessengerService{})) + rootCmd.AddCommand((&TxSubmitterCommand{}).Command(&TxSubmitterService{}, &monitor.MonitorService{})) + rootCmd.AddCommand((&ChannelAccountsCommand{}).Command()) + rootCmd.AddCommand((&IntegrationTestsCommand{}).Command()) + rootCmd.AddCommand((&AuthCommand{}).Command()) + + return rootCmd +} diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 000000000..5df68cc66 --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,65 @@ +package cmd + +import ( + "bytes" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stretchr/testify/assert" +) + +func Test_globalOptions_populateCrashTrackerOptions(t *testing.T) { + globalOptions := globalOptionsType{ + environment: "test", + gitCommit: "1234567890abcdef", + sentryDSN: "test-sentry-dsn", + } + + t.Run("CrashTrackerType is not Sentry", func(t *testing.T) { + crashTrackerOptions := crashtracker.CrashTrackerOptions{} + globalOptions.populateCrashTrackerOptions(&crashTrackerOptions) + + wantCrashTrackerOptions := crashtracker.CrashTrackerOptions{ + Environment: "test", + GitCommit: "1234567890abcdef", + } + assert.Equal(t, wantCrashTrackerOptions, crashTrackerOptions) + }) + + t.Run("CrashTrackerType is Sentry", func(t *testing.T) { + crashTrackerOptions := crashtracker.CrashTrackerOptions{ + CrashTrackerType: crashtracker.CrashTrackerTypeSentry, + } + globalOptions.populateCrashTrackerOptions(&crashTrackerOptions) + + wantCrashTrackerOptions := crashtracker.CrashTrackerOptions{ + Environment: "test", + GitCommit: "1234567890abcdef", + SentryDSN: "test-sentry-dsn", + CrashTrackerType: crashtracker.CrashTrackerTypeSentry, + } + assert.Equal(t, wantCrashTrackerOptions, crashTrackerOptions) + }) +} + +func Test_noArgsAndHelpHaveSameResultAndDoDontPanic(t *testing.T) { + cmdArgsTestCases := [][]string{ + {"--help"}, + {}, + } + + for i, cmdArgs := range cmdArgsTestCases { + // setup + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + rootCmd.SetArgs(cmdArgs) + var out bytes.Buffer + rootCmd.SetOut(&out) + + // test + err := rootCmd.Execute() + assert.NoErrorf(t, err, "test case %d returned an error", i) + + // assert printed text + assert.Containsf(t, out.String(), "Use \"stellar-disbursement-platform [command] --help\" for more information about a command.", "test case %d did not print help message as expected", i) + } +} diff --git a/cmd/serve.go b/cmd/serve.go new file mode 100644 index 000000000..0762f5f18 --- /dev/null +++ b/cmd/serve.go @@ -0,0 +1,399 @@ +package cmd + +import ( + "context" + "fmt" + "go/types" + + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + di "github.com/stellar/stellar-disbursement-platform-backend/internal/dependencyinjection" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/scheduler" + "github.com/stellar/stellar-disbursement-platform-backend/internal/scheduler/jobs" + + "github.com/spf13/cobra" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve" +) + +type ServeCommand struct{} + +type ServerServiceInterface interface { + StartServe(opts serve.ServeOptions, httpServer serve.HTTPServerInterface) + StartMetricsServe(opts serve.MetricsServeOptions, httpServer serve.HTTPServerInterface) + GetSchedulerJobRegistrars(ctx context.Context, serveOpts serve.ServeOptions, schedulerOptions scheduler.SchedulerOptions) ([]scheduler.SchedulerJobRegisterOption, error) +} + +type ServerService struct{} + +// Making sure that ServerService implements ServerServiceInterface +var _ ServerServiceInterface = (*ServerService)(nil) + +func (s *ServerService) StartServe(opts serve.ServeOptions, httpServer serve.HTTPServerInterface) { + err := serve.Serve(opts, httpServer) + if err != nil { + log.Fatalf("Error starting server: %s", err.Error()) + } +} + +func (s *ServerService) StartMetricsServe(opts serve.MetricsServeOptions, httpServer serve.HTTPServerInterface) { + err := serve.MetricsServe(opts, httpServer) + if err != nil { + log.Fatalf("Error starting metrics server: %s", err.Error()) + } +} + +func (s *ServerService) GetSchedulerJobRegistrars(ctx context.Context, serveOpts serve.ServeOptions, schedulerOptions scheduler.SchedulerOptions) ([]scheduler.SchedulerJobRegisterOption, error) { + // TODO: inject these in the server options, to do the Dependency Injection properly. + dbConnectionPool, err := db.OpenDBConnectionPool(globalOptions.databaseURL) + if err != nil { + log.Ctx(ctx).Fatalf("error getting DB connection in Job Scheduler: %s", err.Error()) + } + models, err := data.NewModels(dbConnectionPool) + if err != nil { + log.Ctx(ctx).Fatalf("error creating models in Job Scheduler: %s", err.Error()) + } + + return []scheduler.SchedulerJobRegisterOption{ + scheduler.WithPaymentsProcessorJobOption(models), + scheduler.WithTSSMonitorJobOption(models), + scheduler.WithSendReceiverWalletsSMSInvitationJobOption(jobs.SendReceiverWalletsSMSInvitationJobOptions{ + AnchorPlatformBaseSepURL: serveOpts.AnchorPlatformBaseSepURL, + Models: models, + MessengerClient: serveOpts.SMSMessengerClient, + MinDaysBetweenRetries: schedulerOptions.MinDaysBetweenRetries, + MaxRetries: schedulerOptions.MaxRetries, + Sep10SigningPrivateKey: serveOpts.Sep10SigningPrivateKey, + CrashTrackerClient: serveOpts.CrashTrackerClient.Clone(), + }), + }, nil +} + +func (c *ServeCommand) Command(serverService ServerServiceInterface, monitorService monitor.MonitorServiceInterface) *cobra.Command { + serveOpts := serve.ServeOptions{} + metricsServeOpts := serve.MetricsServeOptions{} + schedulerOptions := scheduler.SchedulerOptions{} + crashTrackerOptions := crashtracker.CrashTrackerOptions{} + + configOpts := config.ConfigOptions{ + { + Name: "port", + Usage: "Port where the server will be listening on", + OptType: types.Int, + ConfigKey: &serveOpts.Port, + FlagDefault: 8000, + Required: true, + }, + { + Name: "metrics-type", + Usage: `Metric monitor type. Options: "PROMETHEUS"`, + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionMetricType, + ConfigKey: &metricsServeOpts.MetricType, + FlagDefault: "PROMETHEUS", + Required: true, + }, + { + Name: "metrics-port", + Usage: "Port where the metrics server will be listening on", + OptType: types.Int, + ConfigKey: &metricsServeOpts.Port, + FlagDefault: 8002, + Required: true, + }, + { + Name: "crash-tracker-type", + Usage: `Crash tracker type. Options: "SENTRY", "DRY_RUN"`, + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionCrashTrackerType, + ConfigKey: &crashTrackerOptions.CrashTrackerType, + FlagDefault: "DRY_RUN", + Required: true, + }, + { + Name: "ec256-public-key", + Usage: "The EC256 Public Key. This key is used to validate the token signature", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionEC256PublicKey, + ConfigKey: &serveOpts.EC256PublicKey, + Required: true, + }, + { + Name: "ec256-private-key", + Usage: "The EC256 Private Key. This key is used to sign the authentication token", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionEC256PrivateKey, + ConfigKey: &serveOpts.EC256PrivateKey, + Required: true, + }, + { + Name: "cors-allowed-origins", + Usage: `Cors URLs that are allowed to access the endpoints, separated by ","`, + OptType: types.String, + CustomSetValue: cmdUtils.SetCorsAllowedOrigins, + ConfigKey: &serveOpts.CorsAllowedOrigins, + Required: true, + }, + { + Name: "sep24-jwt-secret", + Usage: `The JWT secret that's used by the Anchor Platform to sign the SEP-24 JWT token`, + OptType: types.String, + ConfigKey: &serveOpts.SEP24JWTSecret, + Required: true, + }, + { + Name: "sep10-signing-public-key", + Usage: "The public key of the Stellar account that signs the SEP-10 transactions. It's also used to sign URLs.", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPublicKey, + ConfigKey: &serveOpts.Sep10SigningPublicKey, + Required: true, + }, + { + Name: "sep10-signing-private-key", + Usage: "The private key of the Stellar account that signs the SEP-10 transactions. It's also used to sign URLs.", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPrivateKey, + ConfigKey: &serveOpts.Sep10SigningPrivateKey, + Required: true, + }, + { + Name: "anchor-platform-base-platform-url", + Usage: "The Base URL of the platform server of the anchor platform. This is the base URL where the Anchor Platform " + + "exposes its private API that is meant to be reached only by the SDP server, such as the PATCH /sep24/transactions endpoint.", + OptType: types.String, + ConfigKey: &serveOpts.AnchorPlatformBasePlatformURL, + Required: true, + }, + { + Name: "anchor-platform-base-sep-url", + Usage: "The Base URL of the sep server of the anchor platform. This is the base URL where the Anchor Platform " + + "exposes its public API that is meant to be reached by a client application, such as the stellar.toml file.", + OptType: types.String, + ConfigKey: &serveOpts.AnchorPlatformBaseSepURL, + Required: true, + }, + { + Name: "anchor-platform-outgoing-jwt-secret", + Usage: "The JWT secret used to create a JWT token used to send requests to the anchor platform.", + OptType: types.String, + ConfigKey: &serveOpts.AnchorPlatformOutgoingJWTSecret, + Required: false, + }, + { + Name: "reset-token-expiration-hours", + Usage: "The expiration time in hours of the Reset Token", + OptType: types.Int, + ConfigKey: &serveOpts.ResetTokenExpirationHours, + FlagDefault: 24, + Required: true, + }, + { + Name: "min-days-between-retries", + Usage: "The minimum amount of days that the invitation SMS was sent to the Receiver Wallets before we send the invitation again.", + OptType: types.Int, + ConfigKey: &schedulerOptions.MinDaysBetweenRetries, + FlagDefault: 7, + Required: true, + }, + { + Name: "max-retries", + Usage: "The maximum amount of tries to send the SMS invitation to the Receiver Wallets.", + OptType: types.Int, + ConfigKey: &schedulerOptions.MaxRetries, + FlagDefault: 3, + Required: true, + }, + { + Name: "distribution-public-key", + Usage: "The public key of the Stellar distribution account that sends the Stellar payments.", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPublicKey, + ConfigKey: &serveOpts.DistributionPublicKey, + Required: true, + }, + { + Name: "distribution-seed", + Usage: "The private key of the Stellar account used to disburse funds", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPrivateKey, + ConfigKey: &serveOpts.DistributionSeed, + Required: true, + }, + { + Name: "recaptcha-site-key", + Usage: "The Google 'reCAPTCHA v2 - I'm not a robot' site key.", + OptType: types.String, + ConfigKey: &serveOpts.ReCAPTCHASiteKey, + Required: true, + }, + { + Name: "recaptcha-site-secret-key", + Usage: "The Google 'reCAPTCHA v2 - I'm not a robot' site SECRET key.", + OptType: types.String, + ConfigKey: &serveOpts.ReCAPTCHASiteSecretKey, + Required: true, + }, + { + Name: "sdp-ui-base-url", + Usage: "The SDP UI Base URL.", + OptType: types.String, + ConfigKey: &serveOpts.UIBaseURL, + FlagDefault: "http://localhost:3000", + CustomSetValue: cmdUtils.SetConfigOptionURLString, + Required: true, + }, + { + Name: "enable-mfa", + Usage: "Enable MFA using email.", + OptType: types.Bool, + ConfigKey: &serveOpts.EnableMFA, + FlagDefault: true, + Required: false, + }, + { + Name: "enable-recaptcha", + Usage: "Enable ReCAPTCHA for login and forgot password.", + OptType: types.Bool, + ConfigKey: &serveOpts.EnableReCAPTCHA, + FlagDefault: true, + Required: false, + }, + { + Name: "horizon-url", + Usage: "Stellar Horizon URL.", + OptType: types.String, + ConfigKey: &serveOpts.HorizonURL, + FlagDefault: horizonclient.DefaultTestNetClient.HorizonURL, + Required: true, + }, + } + + messengerOptions := message.MessengerOptions{} + + // messenger config options: + configOpts = append(configOpts, cmdUtils.TwilioConfigOptions(&messengerOptions)...) + configOpts = append(configOpts, cmdUtils.AWSConfigOptions(&messengerOptions)...) + + // sms + smsOpts := di.SMSClientOptions{MessengerOptions: &messengerOptions} + configOpts = append(configOpts, + &config.ConfigOption{ + // message sender type + Name: "sms-sender-type", + Usage: fmt.Sprintf("SMS Sender Type. Options: %+v", message.MessengerType("").ValidSMSTypes()), + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionMessengerType, + ConfigKey: &smsOpts.SMSType, + FlagDefault: string(message.MessengerTypeDryRun), + Required: true, + }) + + // email + emailOpts := di.EmailClientOptions{MessengerOptions: &messengerOptions} + configOpts = append(configOpts, + &config.ConfigOption{ + // message sender type + Name: "email-sender-type", + Usage: fmt.Sprintf("Email Sender Type. Options: %+v", message.MessengerType("").ValidEmailTypes()), + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionMessengerType, + ConfigKey: &emailOpts.EmailType, + FlagDefault: string(message.MessengerTypeDryRun), + Required: true, + }) + + cmd := &cobra.Command{ + Use: "serve", + Short: "Serve the Stellar Disbursement Platform API", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Fatalf("Error setting values of config options: %s", err.Error()) + } + + // Initializing monitor service + metricOptions := monitor.MetricOptions{ + MetricType: metricsServeOpts.MetricType, + Environment: globalOptions.environment, + } + + err = monitorService.Start(metricOptions) + if err != nil { + log.Fatalf("Error creating monitor service: %s", err.Error()) + } + + // Inject crash tracker options dependencies + globalOptions.populateCrashTrackerOptions(&crashTrackerOptions) + + // Inject server dependencies + serveOpts.Environment = globalOptions.environment + serveOpts.GitCommit = globalOptions.gitCommit + serveOpts.DatabaseDSN = globalOptions.databaseURL + serveOpts.Version = globalOptions.version + serveOpts.MonitorService = monitorService + serveOpts.BaseURL = globalOptions.baseURL + serveOpts.NetworkPassphrase = globalOptions.networkPassphrase + + // Inject metrics server dependencies + metricsServeOpts.MonitorService = monitorService + metricsServeOpts.Environment = globalOptions.environment + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + + // Setup default Crash Tracker client + crashTrackerClient, err := di.NewCrashTracker(ctx, crashTrackerOptions) + if err != nil { + log.Ctx(ctx).Fatalf("error creating crash tracker client: %s", err.Error()) + } + serveOpts.CrashTrackerClient = crashTrackerClient + + // Setup default Email client + emailMessengerClient, err := di.NewEmailClient(emailOpts) + if err != nil { + log.Ctx(ctx).Fatalf("error creating email client: %s", err.Error()) + } + serveOpts.EmailMessengerClient = emailMessengerClient + + // Setup default SMS client + smsMessengerClient, err := di.NewSMSClient(smsOpts) + if err != nil { + log.Ctx(ctx).Fatalf("error creating SMS client: %s", err.Error()) + } + serveOpts.SMSMessengerClient = smsMessengerClient + + // Starting Scheduler Service (background job) + log.Ctx(ctx).Info("Starting Scheduler Service...") + schedulerJobRegistrats, err := serverService.GetSchedulerJobRegistrars(ctx, serveOpts, schedulerOptions) + if err != nil { + log.Ctx(ctx).Fatalf("Error getting scheduler job registrars: %s", err.Error()) + } + go scheduler.StartScheduler(crashTrackerClient.Clone(), schedulerJobRegistrats...) + + // Starting Metrics Server (background job) + log.Ctx(ctx).Info("Starting Metrics Server...") + go serverService.StartMetricsServe(metricsServeOpts, &serve.HTTPServer{}) + + // Starting Application Server + log.Ctx(ctx).Info("Starting Application Server...") + serverService.StartServe(serveOpts, &serve.HTTPServer{}) + }, + } + err := configOpts.Init(cmd) + if err != nil { + log.Fatalf("Error initializing a config option: %s", err.Error()) + } + + return cmd +} diff --git a/cmd/serve_test.go b/cmd/serve_test.go new file mode 100644 index 000000000..2aca9470a --- /dev/null +++ b/cmd/serve_test.go @@ -0,0 +1,194 @@ +package cmd + +import ( + "bytes" + "context" + "sync" + "testing" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/network" + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + di "github.com/stellar/stellar-disbursement-platform-backend/internal/dependencyinjection" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/scheduler" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockServer struct { + wg sync.WaitGroup + mock.Mock +} + +// Making sure that mockServer implements ServerServiceInterface +var _ ServerServiceInterface = (*mockServer)(nil) + +func (m *mockServer) StartServe(opts serve.ServeOptions, httpServer serve.HTTPServerInterface) { + m.Called(opts, httpServer) + m.wg.Wait() +} + +func (m *mockServer) StartMetricsServe(opts serve.MetricsServeOptions, httpServer serve.HTTPServerInterface) { + m.Called(opts, httpServer) + m.wg.Done() +} + +func (m *mockServer) GetSchedulerJobRegistrars(ctx context.Context, serveOpts serve.ServeOptions, schedulerOptions scheduler.SchedulerOptions) ([]scheduler.SchedulerJobRegisterOption, error) { + args := m.Called(ctx, serveOpts, schedulerOptions) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]scheduler.SchedulerJobRegisterOption), args.Error(1) +} + +func Test_serve_wasCalled(t *testing.T) { + // setup + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + serveCmdFound := false + + for _, cmd := range rootCmd.Commands() { + if cmd.Use == "serve" { + serveCmdFound = true + } + } + require.True(t, serveCmdFound, "serve command not found") + rootCmd.SetArgs([]string{"serve", "--help"}) + var out bytes.Buffer + rootCmd.SetOut(&out) + + // test + err := rootCmd.Execute() + require.NoError(t, err) + + // assert + assert.Contains(t, out.String(), "stellar-disbursement-platform serve [flags]", "should have printed help message for serve command") +} + +func Test_serve(t *testing.T) { + dbt := dbtest.Open(t) + randomDatabaseDSN := dbt.DSN + dbt.Close() + + cmdUtils.ClearTestEnvironment(t) + + ctx := context.Background() + + // mock metric service + mMonitorService := monitor.MockMonitorService{} + + serveOpts := serve.ServeOptions{ + Environment: "test", + GitCommit: "1234567890abcdef", + Port: 8000, + Version: "x.y.z", + MonitorService: &mMonitorService, + DatabaseDSN: randomDatabaseDSN, + EC256PublicKey: "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAER88h7AiQyVDysRTxKvBB6CaiO/kS\ncvGyimApUE/12gFhNTRf37SE19CSCllKxstnVFOpLLWB7Qu5OJ0Wvcz3hg==\n-----END PUBLIC KEY-----", + EC256PrivateKey: "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgIqI1MzMZIw2pQDLx\nJn0+FcNT/hNjwtn2TW43710JKZqhRANCAARHzyHsCJDJUPKxFPEq8EHoJqI7+RJy\n8bKKYClQT/XaAWE1NF/ftITX0JIKWUrGy2dUU6kstYHtC7k4nRa9zPeG\n-----END PRIVATE KEY-----", + CorsAllowedOrigins: []string{"*"}, + SEP24JWTSecret: "jwt_secret_1234567890", + BaseURL: "https://sdp.com", + UIBaseURL: "http://localhost:3000", + ResetTokenExpirationHours: 24, + NetworkPassphrase: network.TestNetworkPassphrase, + HorizonURL: horizonclient.DefaultTestNetClient.HorizonURL, + Sep10SigningPublicKey: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + Sep10SigningPrivateKey: "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5", + AnchorPlatformBaseSepURL: "localhost:8080", + AnchorPlatformBasePlatformURL: "localhost:8085", + DistributionPublicKey: "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", + DistributionSeed: "SBHQEYSACD5DOK5I656NKLAMOHC6VT64ATOWWM2VJ3URGDGMVGNPG4ON", + ReCAPTCHASiteKey: "reCAPTCHASiteKey", + ReCAPTCHASiteSecretKey: "reCAPTCHASiteSecretKey", + EnableMFA: true, + EnableReCAPTCHA: true, + } + + crashTrackerClient, err := di.NewCrashTracker(ctx, crashtracker.CrashTrackerOptions{ + Environment: serveOpts.Environment, + GitCommit: serveOpts.GitCommit, + CrashTrackerType: "DRY_RUN", + }) + require.NoError(t, err) + serveOpts.CrashTrackerClient = crashTrackerClient + + messengerClient, err := di.NewEmailClient(di.EmailClientOptions{EmailType: message.MessengerTypeDryRun}) + require.NoError(t, err) + serveOpts.EmailMessengerClient = messengerClient + + smsMessengerClient, err := di.NewSMSClient(di.SMSClientOptions{SMSType: message.MessengerTypeDryRun}) + require.NoError(t, err) + serveOpts.SMSMessengerClient = smsMessengerClient + + metricOptions := monitor.MetricOptions{ + MetricType: monitor.MetricTypePrometheus, + Environment: "test", + } + mMonitorService.On("Start", metricOptions).Return(nil).Once() + + serveMetricOpts := serve.MetricsServeOptions{ + Port: 8002, + Environment: "test", + + MetricType: monitor.MetricTypePrometheus, + MonitorService: &mMonitorService, + } + + schedulerOptions := scheduler.SchedulerOptions{ + MinDaysBetweenRetries: 7, + MaxRetries: 3, + } + + // mock server + mServer := mockServer{} + mServer.On("StartMetricsServe", serveMetricOpts, mock.AnythingOfType("*serve.HTTPServer")).Once() + mServer.On("StartServe", serveOpts, mock.AnythingOfType("*serve.HTTPServer")).Once() + mServer. + On("GetSchedulerJobRegistrars", mock.AnythingOfType("*context.emptyCtx"), serveOpts, schedulerOptions). + Return([]scheduler.SchedulerJobRegisterOption{}, nil). + Once() + mServer.wg.Add(1) + + // SetupCLI and replace the serve command with one containing a mocked server + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + originalCommands := rootCmd.Commands() + rootCmd.ResetCommands() + serveCmdFound := false + for _, cmd := range originalCommands { + if cmd.Use == "serve" { + serveCmdFound = true + rootCmd.AddCommand((&ServeCommand{}).Command(&mServer, &mMonitorService)) + } else { + rootCmd.AddCommand(cmd) + } + } + require.True(t, serveCmdFound, "serve command not found") + + t.Setenv("DATABASE_URL", serveOpts.DatabaseDSN) + t.Setenv("EC256_PUBLIC_KEY", serveOpts.EC256PublicKey) + t.Setenv("EC256_PRIVATE_KEY", serveOpts.EC256PrivateKey) + t.Setenv("SEP24_JWT_SECRET", serveOpts.SEP24JWTSecret) + t.Setenv("SEP10_SIGNING_PUBLIC_KEY", serveOpts.Sep10SigningPublicKey) + t.Setenv("SEP10_SIGNING_PRIVATE_KEY", serveOpts.Sep10SigningPrivateKey) + t.Setenv("ANCHOR_PLATFORM_BASE_SEP_URL", serveOpts.AnchorPlatformBaseSepURL) + t.Setenv("ANCHOR_PLATFORM_BASE_PLATFORM_URL", serveOpts.AnchorPlatformBasePlatformURL) + t.Setenv("DISTRIBUTION_PUBLIC_KEY", serveOpts.DistributionPublicKey) + t.Setenv("DISTRIBUTION_SEED", serveOpts.DistributionSeed) + t.Setenv("BASE_URL", serveOpts.BaseURL) + t.Setenv("RECAPTCHA_SITE_KEY", serveOpts.ReCAPTCHASiteKey) + t.Setenv("RECAPTCHA_SITE_SECRET_KEY", serveOpts.ReCAPTCHASiteSecretKey) + t.Setenv("CORS_ALLOWED_ORIGINS", "*") + + // test & assert + rootCmd.SetArgs([]string{"--environment", "test", "serve", "--metrics-type", "PROMETHEUS"}) + err = rootCmd.Execute() + require.NoError(t, err) + mServer.AssertExpectations(t) + mMonitorService.AssertExpectations(t) +} diff --git a/cmd/transaction_submitter.go b/cmd/transaction_submitter.go new file mode 100644 index 000000000..1ab162bf2 --- /dev/null +++ b/cmd/transaction_submitter.go @@ -0,0 +1,196 @@ +package cmd + +import ( + "context" + "go/types" + "os" + "os/signal" + "syscall" + + "github.com/spf13/cobra" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + "github.com/stellar/go/txnbuild" + + cmdUtils "github.com/stellar/stellar-disbursement-platform-backend/cmd/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + di "github.com/stellar/stellar-disbursement-platform-backend/internal/dependencyinjection" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve" + txSub "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission" + tssUtils "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" +) + +type TxSubmitterCommand struct{} + +type TxSubmitterServiceInterface interface { + StartSubmitter(context.Context, txSub.SubmitterOptions) + StartMetricsServe(ctx context.Context, opts serve.MetricsServeOptions, httpServer serve.HTTPServerInterface, crashTrackerClient crashtracker.CrashTrackerClient) +} + +type TxSubmitterService struct{} + +// StartSubmitter starts the Transaction Submission Service +func (t *TxSubmitterService) StartSubmitter(ctx context.Context, opts txSub.SubmitterOptions) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + // Wait for a termination signal + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGTERM) + <-sig + + // Cancel the context to signal the submitterService to exit + cancel() + }() + + tssManager, err := txSub.NewManager(ctx, opts) + if err != nil { + opts.CrashTrackerClient.LogAndReportErrors(ctx, err, "Cannot start submitter service") + log.Fatalf("Error starting transaction submission service: %s", err.Error()) + } + + tssManager.ProcessTransactions(ctx) +} + +func (s *TxSubmitterService) StartMetricsServe(ctx context.Context, opts serve.MetricsServeOptions, httpServer serve.HTTPServerInterface, crashTrackerClient crashtracker.CrashTrackerClient) { + err := serve.MetricsServe(opts, httpServer) + if err != nil { + crashTrackerClient.LogAndReportErrors(ctx, err, "Cannot start metrics service") + log.Fatalf("Error starting metrics server: %s", err.Error()) + } +} + +func (c *TxSubmitterCommand) Command(submitterService TxSubmitterServiceInterface, monitorService monitor.MonitorServiceInterface) *cobra.Command { + submitterOpts := txSub.SubmitterOptions{} + metricsServeOpts := serve.MetricsServeOptions{} + crashTrackerOptions := crashtracker.CrashTrackerOptions{} + + configOpts := config.ConfigOptions{ + { + Name: "tss-metrics-port", + Usage: `Port where the metrics server will be listening on. Default: 9002"`, + OptType: types.Int, + ConfigKey: &metricsServeOpts.Port, + FlagDefault: 9002, + Required: true, + }, + { + Name: "tss-metrics-type", + Usage: `Metric monitor type. Options: "TSS_PROMETHEUS"`, + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionMetricType, + ConfigKey: &metricsServeOpts.MetricType, + FlagDefault: "TSS_PROMETHEUS", + Required: true, + }, + { + Name: "distribution-seed", + Usage: "The private key of the Stellar account used to disburse funds", + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionStellarPrivateKey, + ConfigKey: &submitterOpts.DistributionSeed, + Required: true, + }, + { + Name: "horizon-url", + Usage: "Horizon URL", + OptType: types.String, + ConfigKey: &submitterOpts.HorizonURL, + FlagDefault: horizonclient.DefaultTestNetClient.HorizonURL, + Required: true, + }, + { + Name: "num-channel-accounts", + Usage: "Number of channel accounts to utilize for transaction submission", + OptType: types.Int, + ConfigKey: &submitterOpts.NumChannelAccounts, + FlagDefault: 2, + Required: false, + }, + { + Name: "queue-polling-interval", + Usage: "Polling interval (seconds) to query the database for pending transactions to process", + OptType: types.Int, + ConfigKey: &submitterOpts.QueuePollingInterval, + FlagDefault: 6, + Required: true, + }, + { + Name: "max-base-fee", + Usage: "The max base fee for submitting a Stellar transaction", + OptType: types.Int, + ConfigKey: &submitterOpts.MaxBaseFee, + FlagDefault: txnbuild.MinBaseFee, + Required: true, + }, + { + Name: "crash-tracker-type", + Usage: `Crash tracker type. Options: "SENTRY", "DRY_RUN"`, + OptType: types.String, + CustomSetValue: cmdUtils.SetConfigOptionCrashTrackerType, + ConfigKey: &crashTrackerOptions.CrashTrackerType, + FlagDefault: "DRY_RUN", + Required: true, + }, + } + cmd := &cobra.Command{ + Use: "tss", + Short: "Run the Transaction Submission Service", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + ctx := cmd.Context() + + // Validate & ingest input parameters + configOpts.Require() + err := configOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("Error setting values of config options: %s", err.Error()) + } + + // Initializing monitor service + metricOptions := monitor.MetricOptions{ + MetricType: metricsServeOpts.MetricType, + Environment: globalOptions.environment, + } + + // Inject metrics dependencies + err = monitorService.Start(metricOptions) + if err != nil { + log.Ctx(ctx).Fatalf("Error creating monitor service: %s", err.Error()) + } + metricsServeOpts.MonitorService = monitorService + + // Inject server dependencies + submitterOpts.MonitorService = monitorService + submitterOpts.DatabaseDSN = globalOptions.databaseURL + submitterOpts.NetworkPassphrase = globalOptions.networkPassphrase + submitterOpts.PrivateKeyEncrypter = tssUtils.DefaultPrivateKeyEncrypter{} + + // Inject crash tracker options dependencies + globalOptions.populateCrashTrackerOptions(&crashTrackerOptions) + // Setup default Crash Tracker client + crashTrackerClient, err := di.NewCrashTracker(ctx, crashTrackerOptions) + if err != nil { + log.Ctx(ctx).Fatalf("error creating crash tracker client: %s", err.Error()) + } + submitterOpts.CrashTrackerClient = crashTrackerClient + }, + Run: func(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() + // Starting Metrics Server (background job) + go submitterService.StartMetricsServe(ctx, metricsServeOpts, &serve.HTTPServer{}, submitterOpts.CrashTrackerClient) + + // Start transaction submission service + submitterService.StartSubmitter(ctx, submitterOpts) + }, + } + err := configOpts.Init(cmd) + if err != nil { + log.Fatalf("Error initializing a config option: %s", err.Error()) + } + + return cmd +} diff --git a/cmd/transaction_submitter_test.go b/cmd/transaction_submitter_test.go new file mode 100644 index 000000000..57db054d3 --- /dev/null +++ b/cmd/transaction_submitter_test.go @@ -0,0 +1,135 @@ +package cmd + +import ( + "bytes" + "context" + "os" + "strings" + "sync" + "testing" + + "github.com/spf13/cobra" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve" + txSub "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission" + tssUtils "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockSubmitter struct { + mock.Mock + wg sync.WaitGroup +} + +func (t *mockSubmitter) StartSubmitter(ctx context.Context, opts txSub.SubmitterOptions) { + t.Called(ctx, opts) + t.wg.Wait() +} + +func (t *mockSubmitter) StartMock(opts txSub.SubmitterOptions) { + t.Called(opts) +} + +func (t *mockSubmitter) StartMetricsServe(ctx context.Context, opts serve.MetricsServeOptions, httpServer serve.HTTPServerInterface, crashTrackerClient crashtracker.CrashTrackerClient) { + t.Called(ctx, opts, httpServer, crashTrackerClient) + t.wg.Done() +} + +func Test_tss_help(t *testing.T) { + // setup + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + tssCmdFound := false + + for _, cmd := range rootCmd.Commands() { + if cmd.Use == "tss" { + tssCmdFound = true + } + } + require.True(t, tssCmdFound, "tss command not found") + rootCmd.SetArgs([]string{"tss", "--help"}) + var out bytes.Buffer + rootCmd.SetOut(&out) + + // test + err := rootCmd.Execute() + require.NoError(t, err) + + // assert + assert.Contains(t, out.String(), "stellar-disbursement-platform tss [flags]", "should have printed help message for tss command") +} + +func Test_tss(t *testing.T) { + for _, env := range os.Environ() { + key := env[:strings.Index(env, "=")] + t.Setenv(key, "") + } + + dryRunClient, err := crashtracker.NewDryRunClient() + require.NoError(t, err) + + mMonitorService := monitor.MockMonitorService{} + wantSubmitterOptions := txSub.SubmitterOptions{ + DatabaseDSN: "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable", + HorizonURL: "https://horizon-testnet.stellar.org", + DistributionSeed: "SBQ3ZNC2SE3FV43HZ2KW3FCXQMMIQ33LZB745KTMCHDS6PNQOVXMV5NC", + NetworkPassphrase: "Test SDF Network ; September 2015", + MaxBaseFee: 100, + NumChannelAccounts: 2, + QueuePollingInterval: 6, + MonitorService: &mMonitorService, + CrashTrackerClient: dryRunClient, + PrivateKeyEncrypter: tssUtils.DefaultPrivateKeyEncrypter{}, + } + + metricOptions := monitor.MetricOptions{ + MetricType: monitor.MetricTypeTSSPrometheus, + Environment: "test", + } + mMonitorService.On("Start", metricOptions).Return(nil).Once() + + serveMetricOpts := serve.MetricsServeOptions{ + Port: 9002, + MetricType: monitor.MetricTypeTSSPrometheus, + MonitorService: &mMonitorService, + } + + mTSS := mockSubmitter{} + rootCmd := SetupCLI("x.y.z", "1234567890abcdef") + + mTSS.On("StartMetricsServe", mock.Anything, serveMetricOpts, mock.AnythingOfType("*serve.HTTPServer"), dryRunClient).Once() + mTSS.On("StartSubmitter", mock.Anything, wantSubmitterOptions).Once() + mTSS.wg.Add(1) + // setup + var commandToRemove *cobra.Command + commandToAdd := (&TxSubmitterCommand{}).Command(&mTSS, &mMonitorService) + for _, cmd := range rootCmd.Commands() { + if cmd.Use == "tss" { + commandToRemove = cmd + } + } + require.NotNil(t, commandToRemove, "tss command not found") + rootCmd.RemoveCommand(commandToRemove) + rootCmd.AddCommand(commandToAdd) + rootCmd.SetArgs([]string{ + "tss", + "--environment", "test", + "--database-url", "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable", + "--distribution-seed", "SBQ3ZNC2SE3FV43HZ2KW3FCXQMMIQ33LZB745KTMCHDS6PNQOVXMV5NC", + "--horizon-url", "https://horizon-testnet.stellar.org", + "--network-passphrase", "Test SDF Network ; September 2015", + }) + + t.Setenv("DATABASE_URL", "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable") + + // test + err = rootCmd.Execute() + require.NoError(t, err) + + // assert + mTSS.AssertExpectations(t) + mMonitorService.AssertExpectations(t) +} diff --git a/cmd/utils/custom_set_value.go b/cmd/utils/custom_set_value.go new file mode 100644 index 000000000..ca82f28d7 --- /dev/null +++ b/cmd/utils/custom_set_value.go @@ -0,0 +1,205 @@ +package utils + +import ( + "fmt" + "net/url" + "strings" + + "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/stellar/go/keypair" + "github.com/stellar/go/strkey" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +func SetConfigOptionMessengerType(co *config.ConfigOption) error { + senderType := viper.GetString(co.Name) + + messengerType, err := message.ParseMessengerType(senderType) + if err != nil { + return fmt.Errorf("couldn't parse messenger type: %w", err) + } + + *(co.ConfigKey.(*message.MessengerType)) = messengerType + return nil +} + +func SetConfigOptionMetricType(co *config.ConfigOption) error { + metricType := viper.GetString(co.Name) + + metricTypeParsed, err := monitor.ParseMetricType(metricType) + if err != nil { + return fmt.Errorf("couldn't parse metric type: %w", err) + } + + *(co.ConfigKey.(*monitor.MetricType)) = metricTypeParsed + return nil +} + +func SetConfigOptionCrashTrackerType(co *config.ConfigOption) error { + ctType := viper.GetString(co.Name) + + ctTypeParsed, err := crashtracker.ParseCrashTrackerType(ctType) + if err != nil { + return fmt.Errorf("couldn't parse crash tracker type: %w", err) + } + + *(co.ConfigKey.(*crashtracker.CrashTrackerType)) = ctTypeParsed + return nil +} + +func SetConfigOptionLogLevel(co *config.ConfigOption) error { + // parse string to logLevel object + logLevelStr := viper.GetString(co.Name) + logLevel, err := logrus.ParseLevel(logLevelStr) + if err != nil { + return fmt.Errorf("couldn't parse log level: %w", err) + } + + // update the configKey + key, ok := co.ConfigKey.(*logrus.Level) + if !ok { + return fmt.Errorf("configKey has an invalid type %T", co.ConfigKey) + } + *key = logLevel + + // Log for debugging + if config.IsExplicitlySet(co) { + log.Debugf("Setting log level to: %q", logLevel) + log.DefaultLogger.SetLevel(*key) + } else { + log.Debugf("Using default log level: %q", logLevel) + } + return nil +} + +// SetConfigOptionEC256PublicKey parses the config option incoming value and validates if it is a valid EC256PublicKey. +func SetConfigOptionEC256PublicKey(co *config.ConfigOption) error { + key, ok := co.ConfigKey.(*string) + if !ok { + return fmt.Errorf("not a valid EC256PublicKey: the expected type for this config key is a string, but got a %T instead", co.ConfigKey) + } + + publicKey := viper.GetString(co.Name) + + // We must remove the literal \n in case of the config options being set this way + publicKey = strings.Replace(publicKey, `\n`, "\n", -1) + + _, err := utils.ParseECDSAPublicKey(publicKey) + if err != nil { + return fmt.Errorf("parsing EC256PublicKey: %w", err) + } + + *key = publicKey + return nil +} + +// SetConfigOptionEC256PrivateKey parses the config option incoming value and validates if it is a valid EC256PrivateKey. +func SetConfigOptionEC256PrivateKey(co *config.ConfigOption) error { + key, ok := co.ConfigKey.(*string) + if !ok { + return fmt.Errorf("not a valid EC256PrivateKey: the expected type for this config key is a string, but got a %T instead", co.ConfigKey) + } + + privateKey := viper.GetString(co.Name) + + // We must remove the literal \n in case of the config options being set this way + privateKey = strings.Replace(privateKey, `\n`, "\n", -1) + + _, err := utils.ParseECDSAPrivateKey(privateKey) + if err != nil { + return fmt.Errorf("parsing EC256PrivateKey: %w", err) + } + + *key = privateKey + return nil +} + +func SetCorsAllowedOrigins(co *config.ConfigOption) error { + corsAllowedOriginsOptions := viper.GetString(co.Name) + + if corsAllowedOriginsOptions == "" { + return fmt.Errorf("cors allowed addresses cannot be empty") + } + + corsAllowedOrigins := strings.Split(corsAllowedOriginsOptions, ",") + + // validate addresses + for _, address := range corsAllowedOrigins { + _, err := url.ParseRequestURI(address) + if err != nil { + return fmt.Errorf("error parsing cors addresses: %w", err) + } + if address == "*" { + log.Warn(`The value "*" for the CORS Allowed Origins is too permissive and not recommended.`) + } + } + + key, ok := co.ConfigKey.(*[]string) + if !ok { + return fmt.Errorf("the expected type for this config key is a string slice, but got a %T instead", co.ConfigKey) + } + *key = corsAllowedOrigins + + return nil +} + +func SetConfigOptionStellarPublicKey(co *config.ConfigOption) error { + publicKey := viper.GetString(co.Name) + + kp, err := keypair.ParseAddress(publicKey) + if err != nil { + return fmt.Errorf("error validating public key: %w", err) + } + + key, ok := co.ConfigKey.(*string) + if !ok { + return fmt.Errorf("the expected type for this config key is a string, but got a %T instead", co.ConfigKey) + } + *key = kp.Address() + + return nil +} + +func SetConfigOptionStellarPrivateKey(co *config.ConfigOption) error { + privateKey := viper.GetString(co.Name) + + isValid := strkey.IsValidEd25519SecretSeed(privateKey) + if !isValid { + return fmt.Errorf("error validating private key: %q", utils.TruncateString(privateKey, 2)) + } + + key, ok := co.ConfigKey.(*string) + if !ok { + return fmt.Errorf("the expected type for this config key is a string, but got a %T instead", co.ConfigKey) + } + *key = privateKey + + return nil +} + +func SetConfigOptionURLString(co *config.ConfigOption) error { + u := viper.GetString(co.Name) + + if u == "" { + return fmt.Errorf("ui base url cannot be empty") + } + + _, err := url.ParseRequestURI(u) + if err != nil { + return fmt.Errorf("error parsing ui base url: %w", err) + } + + key, ok := co.ConfigKey.(*string) + if !ok { + return fmt.Errorf("the expected type for this config key is a string, but got a %T instead", co.ConfigKey) + } + *key = u + + return nil +} diff --git a/cmd/utils/custom_set_value_test.go b/cmd/utils/custom_set_value_test.go new file mode 100644 index 000000000..dc146da8b --- /dev/null +++ b/cmd/utils/custom_set_value_test.go @@ -0,0 +1,584 @@ +package utils + +import ( + "go/types" + "strings" + "testing" + + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// customSetterTestCase is a test case to test a custom_set_value function. +type customSetterTestCase[T any] struct { + name string + args []string + envValue string + wantErrContains string + wantResult T +} + +// customSetterTester tests a custom_set_value function, according with the customSetterTestCase provided. +func customSetterTester[T any](t *testing.T, tc customSetterTestCase[T], co config.ConfigOption) { + ClearTestEnvironment(t) + if tc.envValue != "" { + envName := strings.ToUpper(co.Name) + envName = strings.ReplaceAll(envName, "-", "_") + t.Setenv(envName, tc.envValue) + } + + // start the CLI command + testCmd := cobra.Command{ + RunE: func(cmd *cobra.Command, args []string) error { + co.Require() + return co.SetValue() + }, + } + // mock the command line output + buf := new(strings.Builder) + testCmd.SetOut(buf) + + // Initialize the command for the given option + err := co.Init(&testCmd) + require.NoError(t, err) + + // execute command line + if len(tc.args) > 0 { + testCmd.SetArgs(tc.args) + } + err = testCmd.Execute() + + // check the result + if tc.wantErrContains != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrContains) + } else { + assert.NoError(t, err) + } + + if !utils.IsEmpty(tc.wantResult) { + destPointer := utils.UnwrapInterfaceToPointer[T](co.ConfigKey) + assert.Equal(t, tc.wantResult, *destPointer) + } +} + +func Test_SetConfigOptionMessengerType(t *testing.T) { + opts := struct{ messengerType message.MessengerType }{} + + co := config.ConfigOption{ + Name: "message-sender-type", + OptType: types.String, + CustomSetValue: SetConfigOptionMessengerType, + ConfigKey: &opts.messengerType, + } + + testCases := []customSetterTestCase[message.MessengerType]{ + { + name: "returns an error if the messenger type is empty", + args: []string{}, + wantErrContains: `couldn't parse messenger type: invalid message sender type ""`, + }, + { + name: "returns an error if the messenger type is invalid", + args: []string{"--message-sender-type", "test"}, + wantErrContains: `couldn't parse messenger type: invalid message sender type "TEST"`, + }, + { + name: "πŸŽ‰ handles messenger type TWILIO_SMS (through CLI args)", + args: []string{"--message-sender-type", "TwIliO_sms"}, + wantResult: message.MessengerTypeTwilioSMS, + }, + { + name: "πŸŽ‰ handles messenger type TWILIO_SMS (through ENV vars)", + envValue: "TwIliO_sms", + wantResult: message.MessengerTypeTwilioSMS, + }, + { + name: "πŸŽ‰ handles messenger type AWS_SMS (through CLI args)", + args: []string{"--message-sender-type", "AWs_SMS"}, + wantResult: message.MessengerTypeAWSSMS, + }, + { + name: "πŸŽ‰ handles messenger type AWS_SMS (through ENV vars)", + envValue: "AWs_SMS", + wantResult: message.MessengerTypeAWSSMS, + }, + { + name: "πŸŽ‰ handles messenger type AWS_EMAIL (through CLI args)", + args: []string{"--message-sender-type", "AWS_EMAIL"}, + wantResult: message.MessengerTypeAWSEmail, + }, + { + name: "πŸŽ‰ handles messenger type AWS_EMAIL (through ENV vars)", + envValue: "AWS_EMAIL", + wantResult: message.MessengerTypeAWSEmail, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.messengerType = "" + customSetterTester[message.MessengerType](t, tc, co) + }) + } +} + +func Test_SetConfigOptionLogLevel(t *testing.T) { + opts := struct{ logrusLevel logrus.Level }{} + + co := config.ConfigOption{ + Name: "log-level", + OptType: types.String, + CustomSetValue: SetConfigOptionLogLevel, + ConfigKey: &opts.logrusLevel, + } + + testCases := []customSetterTestCase[logrus.Level]{ + { + name: "returns an error if the log level is empty", + args: []string{}, + wantErrContains: `couldn't parse log level: not a valid logrus Level: ""`, + }, + { + name: "returns an error if the log level is invalid", + args: []string{"--log-level", "test"}, + wantErrContains: `couldn't parse log level: not a valid logrus Level: "test"`, + }, + { + name: "πŸŽ‰ handles messenger type TRACE (through CLI args)", + args: []string{"--log-level", "TRACE"}, + wantResult: logrus.TraceLevel, + }, + { + name: "πŸŽ‰ handles messenger type TRACE (through ENV vars)", + envValue: "TRACE", + wantResult: logrus.TraceLevel, + }, + { + name: "πŸŽ‰ handles messenger type INFO (through CLI args)", + args: []string{"--log-level", "iNfO"}, + wantResult: logrus.InfoLevel, + }, + { + name: "πŸŽ‰ handles messenger type INFO (through ENV vars)", + envValue: "INFO", + wantResult: logrus.InfoLevel, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.logrusLevel = 0 + customSetterTester[logrus.Level](t, tc, co) + }) + } +} + +func Test_SetConfigOptionMetricType(t *testing.T) { + opts := struct{ metricType monitor.MetricType }{} + + co := config.ConfigOption{ + Name: "metrics-type", + OptType: types.String, + CustomSetValue: SetConfigOptionMetricType, + ConfigKey: &opts.metricType, + } + + testCases := []customSetterTestCase[monitor.MetricType]{ + { + name: "returns an error if the value is empty", + args: []string{}, + wantErrContains: `couldn't parse metric type: invalid metric type ""`, + }, + { + name: "returns an error if the value is not supported", + args: []string{"--metrics-type", "test"}, + wantErrContains: `couldn't parse metric type: invalid metric type "TEST"`, + }, + { + name: "πŸŽ‰ handles crash tracker type (through CLI args): PROMETHEUS", + args: []string{"--metrics-type", "PROMETHEUS"}, + wantResult: monitor.MetricTypePrometheus, + }, + { + name: "πŸŽ‰ handles crash tracker type (through ENV vars): PROMETHEUS", + envValue: "PROMETHEUS", + wantResult: monitor.MetricTypePrometheus, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.metricType = "" + customSetterTester[monitor.MetricType](t, tc, co) + }) + } +} + +func Test_SetConfigOptionCrashTrackerType(t *testing.T) { + opts := struct{ crashTrackerType crashtracker.CrashTrackerType }{} + + co := config.ConfigOption{ + Name: "crash-tracker-type", + OptType: types.String, + CustomSetValue: SetConfigOptionCrashTrackerType, + ConfigKey: &opts.crashTrackerType, + } + + testCases := []customSetterTestCase[crashtracker.CrashTrackerType]{ + { + name: "returns an error if the value is empty", + args: []string{}, + wantErrContains: `couldn't parse crash tracker type: invalid crash tracker type ""`, + }, + { + name: "returns an error if the value is not supported", + args: []string{"--crash-tracker-type", "test"}, + wantErrContains: `couldn't parse crash tracker type: invalid crash tracker type "TEST"`, + }, + { + name: "πŸŽ‰ handles crash tracker type (through CLI args): SENTRY", + args: []string{"--crash-tracker-type", "SeNtRy"}, + wantResult: crashtracker.CrashTrackerTypeSentry, + }, + { + name: "πŸŽ‰ handles crash tracker type (through ENV vars): SENTRY", + envValue: "SENTRY", + wantResult: crashtracker.CrashTrackerTypeSentry, + }, + { + name: "πŸŽ‰ handles crash tracker type (through CLI args): DRY_RUN", + args: []string{"--crash-tracker-type", "DRY_RUN"}, + wantResult: crashtracker.CrashTrackerTypeDryRun, + }, + { + name: "πŸŽ‰ handles crash tracker type (through ENV vars): DRY_RUN", + envValue: "DRY_RUN", + wantResult: crashtracker.CrashTrackerTypeDryRun, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.crashTrackerType = "" + customSetterTester[crashtracker.CrashTrackerType](t, tc, co) + }) + } +} + +func Test_SetConfigOptionEC256PublicKey(t *testing.T) { + opts := struct{ ec256PublicKey string }{} + + co := config.ConfigOption{ + Name: "ec256-public-key", + OptType: types.String, + CustomSetValue: SetConfigOptionEC256PublicKey, + ConfigKey: &opts.ec256PublicKey, + } + + expectedPublicKey := `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAER88h7AiQyVDysRTxKvBB6CaiO/kS +cvGyimApUE/12gFhNTRf37SE19CSCllKxstnVFOpLLWB7Qu5OJ0Wvcz3hg== +-----END PUBLIC KEY-----` + + testCases := []customSetterTestCase[string]{ + { + name: "returns an error if the value is not a PEM string", + args: []string{"--ec256-public-key", "not-a-pem-string"}, + wantErrContains: "parsing EC256PublicKey: failed to decode PEM block containing public key", + }, + { + name: "returns an error if the value is not a x509 string", + args: []string{"--ec256-public-key", "-----BEGIN MY STRING-----\nYWJjZA==\n-----END MY STRING-----"}, + wantErrContains: "parsing EC256PublicKey: failed to parse x509 PKIX public key", + }, + { + name: "returns an error if the value is not a ECDSA public key", + args: []string{"--ec256-public-key", "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyNPqmozv8a2PnXHIkV+F\nmWMFy2YhOFzX12yzjjWkJ3rI9QSEomz4Unkwc6oYrnKEDYlnAgCiCqL2zPr5qNkX\nk5MPU87/wLgEqp7uAk0GkJZfrhJIYZ5AuG9+o69BNeQDEi7F3YdMJj9bvs2Ou1FN\n1zG/8HV969rJ/63fzWsqlNon1j4H5mJ0YbmVh/QLcYPmv7feFZGEj4OSZ4u+eJsw\nat5NPyhMgo6uB/goNS3fEY29UNvXoSIN3hnK3WSxQ79Rjn4V4so7ehxzCVPjnm/G\nFFTgY0hGBobmnxbjI08hEZmYKosjan4YqydGETjKR3UlhBx9y/eqqgL+opNJ8vJs\n2QIDAQAB\n-----END PUBLIC KEY-----"}, + wantErrContains: "parsing EC256PublicKey: public key is not of type ECDSA", + }, + { + name: "πŸŽ‰ handles EC256 public key through the CLI flag", + args: []string{"--ec256-public-key", expectedPublicKey}, + wantResult: expectedPublicKey, + }, + { + name: "πŸŽ‰ handles EC256 public key through the ENV vars", + envValue: expectedPublicKey, + wantResult: expectedPublicKey, + }, + { + name: "πŸŽ‰ handles EC256 public key through the ENV vars & inline line-breaks", + envValue: "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAER88h7AiQyVDysRTxKvBB6CaiO/kS\ncvGyimApUE/12gFhNTRf37SE19CSCllKxstnVFOpLLWB7Qu5OJ0Wvcz3hg==\n-----END PUBLIC KEY-----", + wantResult: expectedPublicKey, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.ec256PublicKey = "" + customSetterTester[string](t, tc, co) + }) + } +} + +func Test_SetConfigOptionEC256PrivateKey(t *testing.T) { + opts := struct{ ec256PrivateKey string }{} + + co := config.ConfigOption{ + Name: "ec256-private-key", + OptType: types.String, + CustomSetValue: SetConfigOptionEC256PrivateKey, + ConfigKey: &opts.ec256PrivateKey, + } + + expectedPrivateKey := `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgIqI1MzMZIw2pQDLx +Jn0+FcNT/hNjwtn2TW43710JKZqhRANCAARHzyHsCJDJUPKxFPEq8EHoJqI7+RJy +8bKKYClQT/XaAWE1NF/ftITX0JIKWUrGy2dUU6kstYHtC7k4nRa9zPeG +-----END PRIVATE KEY-----` + + testCases := []customSetterTestCase[string]{ + { + name: "returns an error if the value is not a PEM string", + args: []string{"--ec256-private-key", "not-a-pem-string"}, + wantErrContains: "parsing EC256PrivateKey: failed to decode PEM block containing private key", + }, + { + name: "returns an error if the value is not a x509 string", + args: []string{"--ec256-private-key", "-----BEGIN MY STRING-----\nYWJjZA==\n-----END MY STRING-----"}, + wantErrContains: "parsing EC256PrivateKey: failed to parse EC private key", + }, + { + name: "returns an error if the value is not a ECDSA private key", + args: []string{"--ec256-private-key", "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyNPqmozv8a2PnXHIkV+F\nmWMFy2YhOFzX12yzjjWkJ3rI9QSEomz4Unkwc6oYrnKEDYlnAgCiCqL2zPr5qNkX\nk5MPU87/wLgEqp7uAk0GkJZfrhJIYZ5AuG9+o69BNeQDEi7F3YdMJj9bvs2Ou1FN\n1zG/8HV969rJ/63fzWsqlNon1j4H5mJ0YbmVh/QLcYPmv7feFZGEj4OSZ4u+eJsw\nat5NPyhMgo6uB/goNS3fEY29UNvXoSIN3hnK3WSxQ79Rjn4V4so7ehxzCVPjnm/G\nFFTgY0hGBobmnxbjI08hEZmYKosjan4YqydGETjKR3UlhBx9y/eqqgL+opNJ8vJs\n2QIDAQAB\n-----END PUBLIC KEY-----"}, + wantErrContains: "parsing EC256PrivateKey: failed to parse EC private key", + }, + { + name: "πŸŽ‰ handles EC256 private key through the CLI flag", + args: []string{"--ec256-private-key", expectedPrivateKey}, + wantResult: expectedPrivateKey, + }, + { + name: "πŸŽ‰ handles EC256 private key through the ENV vars", + envValue: expectedPrivateKey, + wantResult: expectedPrivateKey, + }, + { + name: "πŸŽ‰ handles EC256 private key through the ENV vars & inline line-breaks", + envValue: `-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgIqI1MzMZIw2pQDLx\nJn0+FcNT/hNjwtn2TW43710JKZqhRANCAARHzyHsCJDJUPKxFPEq8EHoJqI7+RJy\n8bKKYClQT/XaAWE1NF/ftITX0JIKWUrGy2dUU6kstYHtC7k4nRa9zPeG\n-----END PRIVATE KEY-----`, + wantResult: expectedPrivateKey, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.ec256PrivateKey = "" + customSetterTester[string](t, tc, co) + }) + } +} + +func Test_SetConfigOptionStellarPublicKey(t *testing.T) { + opts := struct{ sep10SigningPublicKey string }{} + + co := config.ConfigOption{ + Name: "sep10-signing-public-key", + OptType: types.String, + CustomSetValue: SetConfigOptionStellarPublicKey, + ConfigKey: &opts.sep10SigningPublicKey, + } + expectedPublicKey := "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S" + + testCases := []customSetterTestCase[string]{ + { + name: "returns an error if the public key is empty", + wantErrContains: "error validating public key: strkey is 0 bytes long; minimum valid length is 5", + }, + { + name: "returns an error if the public key is invalid", + args: []string{"--sep10-signing-public-key", "invalid_public_key"}, + wantErrContains: "error validating public key: base32 decode failed: illegal base32 data at input byte 18", + }, + { + name: "returns an error if the public key is invalid (private key instead)", + args: []string{"--sep10-signing-public-key", "SDISQRUPIHAO5WIIGY4QRDCINZSA44TX3OIIUK3C63NUKN5DABKEQ276"}, + wantErrContains: "error validating public key: invalid version byte", + }, + { + name: "πŸŽ‰ handles Stellar public key through the CLI flag", + args: []string{"--sep10-signing-public-key", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"}, + wantResult: expectedPublicKey, + }, + { + name: "πŸŽ‰ handles Stellar public key through the ENV vars", + envValue: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + wantResult: expectedPublicKey, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.sep10SigningPublicKey = "" + customSetterTester[string](t, tc, co) + }) + } +} + +func Test_SetConfigOptionStellarPrivateKey(t *testing.T) { + opts := struct{ sep10SigningPrivateKey string }{} + + co := config.ConfigOption{ + Name: "sep10-signing-private-key", + OptType: types.String, + CustomSetValue: SetConfigOptionStellarPrivateKey, + ConfigKey: &opts.sep10SigningPrivateKey, + } + expectedPrivateKey := "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5" + + testCases := []customSetterTestCase[string]{ + { + name: "returns an error if the private key is empty", + wantErrContains: `error validating private key: ""`, + }, + { + name: "returns an error if the private key is invalid", + args: []string{"--sep10-signing-private-key", "invalid_private_key"}, + wantErrContains: `error validating private key: "in...ey"`, + }, + { + name: "returns an error if the private key is invalid (public key instead)", + args: []string{"--sep10-signing-private-key", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"}, + wantErrContains: `error validating private key: "GA...7S"`, + }, + { + name: "πŸŽ‰ handles Stellar private key through the CLI flag", + args: []string{"--sep10-signing-private-key", "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5"}, + wantResult: expectedPrivateKey, + }, + { + name: "πŸŽ‰ handles Stellar private key through the ENV flag", + envValue: "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5", + wantResult: expectedPrivateKey, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.sep10SigningPrivateKey = "" + customSetterTester[string](t, tc, co) + }) + } +} + +func Test_SetCorsAllowedOriginsFunc(t *testing.T) { + opts := struct{ corsAddressesFlag []string }{} + + co := config.ConfigOption{ + Name: "cors-allowed-origins", + OptType: types.String, + CustomSetValue: SetCorsAllowedOrigins, + ConfigKey: &opts.corsAddressesFlag, + Required: false, + } + + testCases := []customSetterTestCase[[]string]{ + { + name: "returns an error if the cors flag is empty", + args: []string{"--cors-allowed-origins", ""}, + wantErrContains: "cors allowed addresses cannot be empty", + }, + { + name: "returns an error if the cors flag results in an empty array", + args: []string{"--cors-allowed-origins", ","}, + wantErrContains: `error parsing cors addresses: parse ""`, + }, + { + name: "πŸŽ‰ handles one url successfully (from CLI args)", + args: []string{"--cors-allowed-origins", "https://foo.test/*"}, + wantResult: []string{"https://foo.test/*"}, + }, + { + name: "πŸŽ‰ handles two urls successfully (from CLI args)", + args: []string{"--cors-allowed-origins", "https://foo.test/*,https://bar.test/*"}, + wantResult: []string{"https://foo.test/*", "https://bar.test/*"}, + }, + { + name: "πŸŽ‰ handles one url successfully (from ENV vars)", + envValue: "https://foo.test/*", + wantResult: []string{"https://foo.test/*"}, + }, + { + name: "πŸŽ‰ handles two urls successfully (from ENV vars)", + envValue: "https://foo.test/*,https://bar.test/*", + wantResult: []string{"https://foo.test/*", "https://bar.test/*"}, + }, + { + name: `logs a warning when the "*" value is used`, + envValue: "*", + wantResult: []string{"*"}, + }, + } + + getEntries := log.DefaultLogger.StartTest(log.WarnLevel) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.corsAddressesFlag = nil + customSetterTester[[]string](t, tc, co) + }) + } + + entries := getEntries() + require.Len(t, entries, 1) + assert.Equal(t, `The value "*" for the CORS Allowed Origins is too permissive and not recommended.`, entries[0].Message) +} + +func Test_SetConfigOptionURLString(t *testing.T) { + opts := struct{ uiBaseURL string }{} + + co := config.ConfigOption{ + Name: "sdp-ui-base-url", + OptType: types.String, + CustomSetValue: SetConfigOptionURLString, + ConfigKey: &opts.uiBaseURL, + FlagDefault: "http://localhost:3000", + Required: false, + } + + testCases := []customSetterTestCase[string]{ + { + name: "returns an error if the ui base url flag is empty", + args: []string{"--sdp-ui-base-url", ""}, + wantErrContains: "ui base url cannot be empty", + }, + { + name: "πŸŽ‰ handles ui base url successfully (from CLI args)", + args: []string{"--sdp-ui-base-url", "https://sdp-ui.org"}, + wantResult: "https://sdp-ui.org", + }, + { + name: "πŸŽ‰ handles ui base url successfully (from ENV vars)", + envValue: "https://sdp-ui.org", + wantResult: "https://sdp-ui.org", + }, + { + name: "πŸŽ‰ handles ui base url DEFAULT value", + wantResult: "http://localhost:3000", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts.uiBaseURL = "" + customSetterTester[string](t, tc, co) + }) + } +} diff --git a/cmd/utils/shared_config_options.go b/cmd/utils/shared_config_options.go new file mode 100644 index 000000000..6d2f436f3 --- /dev/null +++ b/cmd/utils/shared_config_options.go @@ -0,0 +1,79 @@ +package utils + +import ( + "go/types" + + "github.com/stellar/go/support/config" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" +) + +// TwilioConfigOptions returns the config options for Twilio. Relevant for loading configs needed for the messenger type(s): `TWILIO_*`. +func TwilioConfigOptions(opts *message.MessengerOptions) []*config.ConfigOption { + return []*config.ConfigOption{ + { + Name: "twilio-account-sid", + Usage: "The SID of the Twilio account", + OptType: types.String, + ConfigKey: &opts.TwilioAccountSID, + Required: false, + }, + { + Name: "twilio-auth-token", + Usage: "The Auth Token of the Twilio account", + OptType: types.String, + ConfigKey: &opts.TwilioAuthToken, + Required: false, + }, + { + Name: "twilio-service-sid", + Usage: "The service ID used within Twilio to send messages", + OptType: types.String, + ConfigKey: &opts.TwilioServiceSID, + Required: false, + }, + } +} + +// AWSConfigOptions returns the config options for AWS. Relevant for loading configs needed for the messenger type(s): `AWS_*`. +func AWSConfigOptions(opts *message.MessengerOptions) []*config.ConfigOption { + return []*config.ConfigOption{ + // AWS + { + Name: "aws-access-key-id", + Usage: "The AWS access key ID", + OptType: types.String, + ConfigKey: &opts.AWSAccessKeyID, + Required: false, + }, + { + Name: "aws-secret-access-key", + Usage: "The AWS secret access key", + OptType: types.String, + ConfigKey: &opts.AWSSecretAccessKey, + Required: false, + }, + { + Name: "aws-region", + Usage: "The AWS region", + OptType: types.String, + ConfigKey: &opts.AWSRegion, + Required: false, + }, + // AWS SMS (SNS) + { + Name: "aws-sns-sender-id", + Usage: "The sender ID of the aws account sending the SMS message. Uses AWS SNS.", + OptType: types.String, + ConfigKey: &opts.AWSSNSSenderID, + Required: false, + }, + // AWS Email (SES) + { + Name: "aws-ses-sender-id", + Usage: "The email address that AWS will use to send emails. Uses AWS SES.", + OptType: types.String, + ConfigKey: &opts.AWSSESSenderID, + Required: false, + }, + } +} diff --git a/cmd/utils/test_helpers.go b/cmd/utils/test_helpers.go new file mode 100644 index 000000000..d184a9dba --- /dev/null +++ b/cmd/utils/test_helpers.go @@ -0,0 +1,17 @@ +package utils + +import ( + "os" + "strings" + "testing" +) + +// clearTestEnvironment removes all envs from the test environment. It's useful +// to make tests independent from the localhost environment variables. +func ClearTestEnvironment(t *testing.T) { + // remove all envs from tghe test environment + for _, env := range os.Environ() { + key := env[:strings.Index(env, "=")] + t.Setenv(key, "") + } +} diff --git a/dev/.env.example b/dev/.env.example new file mode 100644 index 000000000..3ca6d30a5 --- /dev/null +++ b/dev/.env.example @@ -0,0 +1,7 @@ +# Generate a new keypair for SEP-10 signing +SEP10_SIGNING_PUBLIC_KEY= +SEP10_SIGNING_PRIVATE_KEY= + +# Generate a new keypair for the distribution account +DISTRIBUTION_PUBLIC_KEY= +DISTRIBUTION_SEED= \ No newline at end of file diff --git a/dev/README.md b/dev/README.md new file mode 100644 index 000000000..a652a7d0d --- /dev/null +++ b/dev/README.md @@ -0,0 +1,124 @@ +# Quick Start Guide - First Disbursement + +## Table of Contents + - [Introduction](#introduction) + - [Prerequisites](#prerequisites) + - [Setup](#setup) + - [Build Docker Containers](#build-docker-containers) + - [Create an Owner SDP User](#create-an-owner-sdp-user) + - [Disbursement](#disbursement) + - [Create First Disbursement](#create-first-disbursement) + - [Deposit Money](#deposit-money) + - [Troubleshooting](#troubleshooting) + +## Introduction + +Follow these instructions to get started with the Stellar Disbursement Platform (SDP). + +## Prerequisites + +### Docker + +Make sure you have Docker installed on your system. If not, you can download it from [here](https://www.docker.com/products/docker-desktop). + +### Hosts + +Add the following two hosts to your `/etc/hosts` file: + +```sh +127.0.0.1 sdp-api +127.0.0.1 anchor-platform +``` + +### Stellar accounts +We will need to create and configure two Stellar accounts to be able to use the SDP. +* A Distribution account that will be used for sending funds to receivers. [Create and Fund a Distribution Account](https://developers.stellar.org/docs/stellar-disbursement-platform/getting-started#create-and-fund-a-distribution-account) +* A SEP-10 account that will be used for authentication. It can be created the same way as the distribution account but it doesn't need to be funded. + +The public and private key of these two accounts will be used to configure the SDP in the next step. + +## Setup + +### Build Docker Containers + +1. Navigate to the `dev` directory from the terminal: +```sh +cd dev +``` + +2. Create a `.env` file in the `dev` directory by copying the `.env.example` file: +```sh +cp .env.example .env +``` + +3. Update the `.env` file with the public and private keys of the two accounts created in the previous step. + +4. Execute the following command to create all the necessary Docker containers needed to run SDP: +```sh +docker-compose up +``` + +This will spin up the following services: + +- `sdp_v2_database`: The main SDP and TSS database. +- `anchor-platform-postgres-db`: Database used by the anchor platform. +- `anchor-platform`: A local instance of the anchor platform. +- `sdp-api`: SDP service running on port `8000`. +- `sdp-tss`: Transaction Submission service. +- `sdp-frontend`: SDP frontend service running on port `3000`. + +### Create an Owner SDP User + +Open a terminal for the `sdp-api` container and run the following command to create an owner user: + +```sh +docker exec -it sdp-api bash # Or use Docker Desktop to open terminal +./stellar-disbursement-platform auth add-user owner@stellar.org joe yabuki --password --owner --roles owner +``` + +You will be prompted to enter a password for the user. Be sure to remember it as it will be required for future authentications. + +## Disbursement + +### Create First Disbursement + +Navigate to the frontend service by opening a browser and going to [localhost:3000](http://localhost:3000). + +- Click `New Disbursement+` on the Dashboard screen. +- Use `Demo Wallet` as your wallet. +- Upload a disbursement file. A sample file is available `./dev/sample/sample-disbursement.csv`. Make sure to update the invalid phone number before using it. +- Finally, confirm the disbursement. + +### Deposit Money + +To deposit money into your account: + +- Access [https://demo-wallet.stellar.org/](https://demo-wallet.stellar.org/) in your browser. +- Click on `Generate Keypair for new account` to create a new testnet receiver account. Make sure to save your public key & secret. +- Add an Asset with the following information: + - Asset Code: `USDC` + - Anchor Home Domain: `localhost:8080` + - Issuer Public Key: `GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5` +- Click `Create Account` (in front of public key) and add Trustline for USDC. +- For USDC, select `SEP-24 Deposit`. +- In the new window, enter the phone number from the disbursement CSV. +- Enter the passcode. You can use `000000` passcode or find the actual passcode in the `sdp-api` container logs. +- Enter the birthday that matches the phone number in the CSV. +- Keep an eye on the dashboard until the payment status reaches `Success`. If everything was set up correctly, your money should be disbursed successfully. + +## Troubleshooting + +### Distribution account out of funds + +Payments will start failing if the distribution account runs out of funds. To fix this, you can either write a script that funds the distribution account or use the tools +available to add more funds to the distribution account by following these steps: + +- Find the distribution account public key in `dev/docker-compose.yml` under the variable `DISTRIBUTION_PUBLIC_KEY` +- Access [https://horizon-testnet.stellar.org/accounts/:accountId](https://horizon-testnet.stellar.org/accounts/GARGKDIDH7WMKV5WWPK4BH4CKEQIZGWUCA4EUXCY5VICHTHLEBXVNVMW) in your browser and check the balance. +- If the balance is indeed low, you can add more funds by creating a new account and sending funds to the distribution account. + - Access [https://demo-wallet.stellar.org/](https://demo-wallet.stellar.org/) in your browser. + - Click on `Generate Keypair for new account` to create a new testnet account. Your account comes with 10,000 XLM. + - Click on `Send` and enter the distribution account public key and the amount you want to send. + - Using Freighter or Stellar Laboratory, swap the XLM for USDC. + +You can also just use the newly created account as the distribution account by updating the `DISTRIBUTION_PUBLIC_KEY` variable in `dev/docker-compose.yml` and restarting the `sdp-api` container. diff --git a/dev/docker-compose-frontend.yml b/dev/docker-compose-frontend.yml new file mode 100644 index 000000000..27d2d4f1e --- /dev/null +++ b/dev/docker-compose-frontend.yml @@ -0,0 +1,12 @@ +version: '3.8' +services: + sdp-frontend: + container_name: sdp-frontend + image: stellar/stellar-disbursement-platform-frontend:edge + ports: + - "3000:80" + volumes: + - ./env-config.js:/usr/share/nginx/html/settings/env-config.js + depends_on: + - db + - sdp-api \ No newline at end of file diff --git a/dev/docker-compose-sdp-anchor.yml b/dev/docker-compose-sdp-anchor.yml new file mode 100644 index 000000000..7a5867144 --- /dev/null +++ b/dev/docker-compose-sdp-anchor.yml @@ -0,0 +1,152 @@ +version: '3.8' +services: + db: + container_name: sdp_v2_database + image: postgres:14-alpine + environment: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: sdp + PGDATA: /data/postgres + ports: + - "5432:5432" + volumes: + - postgres-db:/data/postgres + + sdp-api: + container_name: sdp-api + image: stellar/sdp-v2:latest + build: + context: ../ + dockerfile: Dockerfile + ports: + - "8000:8000" + environment: + BASE_URL: http://localhost:8000 + DATABASE_URL: postgres://postgres@db:5432/postgres?sslmode=disable + ENVIRONMENT: localhost + LOG_LEVEL: TRACE + PORT: "8000" + METRICS_PORT: "8002" + METRICS_TYPE: PROMETHEUS + EMAIL_SENDER_TYPE: DRY_RUN + SMS_SENDER_TYPE: DRY_RUN + NETWORK_PASSPHRASE: "Test SDF Network ; September 2015" + EC256_PUBLIC_KEY: "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEJ3HNphPAEKHvtRjsl5Kjwc9tTMqS\n2pmYNybrLsxZ6cuQvg2yiEoXZixP2cJ77csHClXC6cb1wQp/BNGDvGKoPg==\n-----END PUBLIC KEY-----" + SEP10_SIGNING_PUBLIC_KEY: ${SEP10_SIGNING_PUBLIC_KEY} + ANCHOR_PLATFORM_BASE_SEP_URL: http://localhost:8080 + ANCHOR_PLATFORM_BASE_PLATFORM_URL: http://anchor-platform:8085 + DISTRIBUTION_PUBLIC_KEY: ${DISTRIBUTION_PUBLIC_KEY} + DISTRIBUTION_SEED: ${DISTRIBUTION_SEED} + RECAPTCHA_SITE_KEY: 6LeIxAcTAAAAAJcZVRqyHh71UMIEGNQ_MXjiZKhI + CORS_ALLOWED_ORIGINS: http://localhost:3000 + ENABLE_MFA: "false" + ENABLE_RECAPTCHA: "false" + + # secrets: + AWS_ACCESS_KEY_ID: MY_AWS_ACCESS_KEY_ID + AWS_REGION: MY_AWS_REGION + AWS_SECRET_ACCESS_KEY: MY_AWS_SECRET_ACCESS_KEY + AWS_SES_SENDER_ID: MY_AWS_SES_SENDER_ID + TWILIO_ACCOUNT_SID: MY_TWILIO_ACCOUNT_SID + TWILIO_AUTH_TOKEN: MY_TWILIO_AUTH_TOKEN + TWILIO_SERVICE_SID: MY_TWILIO_SERVICE_SID + EC256_PRIVATE_KEY: "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgdo6o+tdFkF94B7z8\nnoybH6/zO3PryLLjLbj54/zOi4WhRANCAAQncc2mE8AQoe+1GOyXkqPBz21MypLa\nmZg3JusuzFnpy5C+DbKIShdmLE/ZwnvtywcKVcLpxvXBCn8E0YO8Yqg+\n-----END PRIVATE KEY-----" + SEP10_SIGNING_PRIVATE_KEY: ${SEP10_SIGNING_PRIVATE_KEY} + SEP24_JWT_SECRET: jwt_secret_1234567890 + RECAPTCHA_SITE_SECRET_KEY: 6LeIxAcTAAAAAGG-vFI1TnRWxMZNFuojJ4WifJWe + ANCHOR_PLATFORM_OUTGOING_JWT_SECRET: mySdpToAnchorPlatformSecret + entrypoint: "" + command: + - sh + - -c + - | + sleep 5 + ./stellar-disbursement-platform db migrate up + ./stellar-disbursement-platform db auth migrate up + ./stellar-disbursement-platform db setup-for-network + ./stellar-disbursement-platform serve + depends_on: + - db + + db-anchor-platform: + container_name: anchor-platform-postgres-db + image: postgres:14-alpine + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + PGPORT: 5433 + ports: + - "5433:5433" + volumes: + - postgres-ap-db:/data/postgres + + anchor-platform: + container_name: anchor-platform + image: stellar/anchor-platform:2.1.3 + command: --sep-server --platform-server --platform linux/amd64 + ports: + - "8080:8080" # sep-server + - "8085:8085" # platform-server + - "8082:8082" # metrics + depends_on: + - db-anchor-platform + environment: + HOST_URL: http://localhost:8080 + SEP_SERVER_PORT: 8080 + CALLBACK_API_BASE_URL: http://sdp-api:8000 + CALLBACK_API_AUTH_TYPE: none # TODO: update to jwt later + PLATFORM_SERVER_AUTH_TYPE: JWT + APP_LOGGING_LEVEL: INFO + DATA_TYPE: postgres + DATA_SERVER: db-anchor-platform:5433 + DATA_DATABASE: postgres + DATA_FLYWAY_ENABLED: "true" + DATA_DDL_AUTO: update + METRICS_ENABLED: "false" # Metrics would be available at port 8082 + METRICS_EXTRAS_ENABLED: "false" + SEP10_ENABLED: "true" + SEP10_HOME_DOMAIN: localhost:8080 + SEP24_ENABLED: "true" + SEP24_INTERACTIVE_URL_BASE_URL: http://localhost:8000/wallet-registration/start + SEP24_INTERACTIVE_URL_JWT_EXPIRATION: 1800 # 1800 seconds is 30 minutes + SEP24_MORE_INFO_URL_BASE_URL: http://localhost:8000/wallet-registration/start + SEP1_ENABLED: "true" + SEP1_TOML_TYPE: url + SEP1_TOML_VALUE: http://sdp-api:8000/.well-known/stellar.toml + ASSETS_TYPE: json + ASSETS_VALUE: | + { + "assets": [ + { + "sep24_enabled": true, + "schema": "stellar", + "code": "USDC", + "issuer": "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "distribution_account": "${DISTRIBUTION_PUBLIC_KEY}", + "significant_decimals": 7, + "deposit": { + "enabled": true, + "fee_minimum": 0, + "fee_percent": 0, + "min_amount": 1, + "max_amount": 10000 + }, + "withdraw": {"enabled": false} + } + ] + } + + # secrets: + SECRET_DATA_USERNAME: postgres + SECRET_DATA_PASSWORD: postgres + SECRET_PLATFORM_API_AUTH_SECRET: mySdpToAnchorPlatformSecret + SECRET_SEP10_JWT_SECRET: jwt_secret_1234567890 + SECRET_SEP10_SIGNING_SEED: ${SEP10_SIGNING_PRIVATE_KEY} + SECRET_SEP24_INTERACTIVE_URL_JWT_SECRET: jwt_secret_1234567890 + SECRET_SEP24_MORE_INFO_URL_JWT_SECRET: jwt_secret_1234567890 +volumes: + postgres-db: + driver: local + postgres-ap-db: + driver: local diff --git a/dev/docker-compose-tss.yml b/dev/docker-compose-tss.yml new file mode 100644 index 000000000..cadcd2127 --- /dev/null +++ b/dev/docker-compose-tss.yml @@ -0,0 +1,33 @@ +version: '3.8' +services: + sdp-tss: + container_name: sdp-tss + image: stellar/sdp-v2:latest + build: + context: ../ + dockerfile: Dockerfile + ports: + - "9000:9000" + environment: + DATABASE_URL: postgres://postgres@db:5432/postgres?sslmode=disable + NETWORK_PASSPHRASE: "Test SDF Network ; September 2015" + HORIZON_URL: "https://horizon-testnet.stellar.org" + NUM_CHANNEL_ACCOUNTS: "3" + MAX_BASE_FEE: "100" + MOCK: "false" + TSS_METRICS_PORT: "9002" + TSS_METRICS_TYPE: "TSS_PROMETHEUS" + DISTRIBUTION_SEED: ${DISTRIBUTION_SEED} + depends_on: + - db + - sdp-api + entrypoint: "" + command: + - sh + - -c + - | + sleep 10 + ./stellar-disbursement-platform channel-accounts verify --delete-invalid-accounts && + ./stellar-disbursement-platform channel-accounts ensure --num-channel-accounts-ensure 1 + ./stellar-disbursement-platform tss + diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml new file mode 100644 index 000000000..2b20d794c --- /dev/null +++ b/dev/docker-compose.yml @@ -0,0 +1,36 @@ +version: '3' +services: + db: + extends: + file: docker-compose-sdp-anchor.yml + service: db + volumes: + - postgres-db:/data/postgres + sdp-api: + extends: + file: docker-compose-sdp-anchor.yml + service: sdp-api + db-anchor-platform: + extends: + file: docker-compose-sdp-anchor.yml + service: db-anchor-platform + volumes: + - postgres-ap-db:/data/postgres + anchor-platform: + extends: + file: docker-compose-sdp-anchor.yml + service: anchor-platform + sdp-tss: + extends: + file: docker-compose-tss.yml + service: sdp-tss + sdp-frontend: + extends: + file: docker-compose-frontend.yml + service: sdp-frontend +volumes: + postgres-db: + driver: local + postgres-ap-db: + driver: local + diff --git a/dev/env-config.js b/dev/env-config.js new file mode 100644 index 000000000..215b3fe73 --- /dev/null +++ b/dev/env-config.js @@ -0,0 +1,7 @@ +window._env_ = { + API_URL: "http://localhost:8000", + STELLAR_EXPERT_URL: "https://stellar.expert/explorer/testnet/tx", + HORIZON_URL: "https://horizon-testnet.stellar.org", + USDC_ASSET_ISSUER: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + RECAPTCHA_SITE_KEY: "6LeIxAcTAAAAAJcZVRqyHh71UMIEGNQ_MXjiZKhI" +}; \ No newline at end of file diff --git a/dev/main.sh b/dev/main.sh new file mode 100755 index 000000000..3c16db169 --- /dev/null +++ b/dev/main.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# This script is used to locally start the integration between SDP and AnchorPlatform for the SEP-24 deposit flow, needed for registering users. +set -eu + +export DIVIDER="----------------------------------------" + +# prepare +echo "====> πŸ‘€Step 1: start preparation" +docker ps -aq | xargs docker stop | xargs docker rm +echo "====> βœ…Step 1: finish preparation" + +# Run docker compose +echo $DIVIDER +echo "====> πŸ‘€Step 2: start calling docker compose up" +docker-compose -f docker-compose-sdp-anchor.yml down && docker-compose -f docker-compose-sdp-anchor.yml up --abort-on-container-exit +echo "====> βœ…Step 2: finish calling docker-compose up" + +echo $DIVIDER +echo "πŸŽ‰πŸŽ‰πŸŽ‰πŸŽ‰ SUCCESS! πŸŽ‰πŸŽ‰πŸŽ‰πŸŽ‰" diff --git a/dev/sample/sample-disbursement.csv b/dev/sample/sample-disbursement.csv new file mode 100644 index 000000000..151fa6dca --- /dev/null +++ b/dev/sample/sample-disbursement.csv @@ -0,0 +1,2 @@ +phone,id,amount,verification ++15550111111,4ba1,2,1987-12-01 \ No newline at end of file diff --git a/docs/images/admin_schema.png b/docs/images/admin_schema.png new file mode 100644 index 000000000..747ec838a Binary files /dev/null and b/docs/images/admin_schema.png differ diff --git a/docs/images/core_schema.png b/docs/images/core_schema.png new file mode 100644 index 000000000..4a7305eb9 Binary files /dev/null and b/docs/images/core_schema.png differ diff --git a/docs/images/high_level_architecture.png b/docs/images/high_level_architecture.png new file mode 100644 index 000000000..affb5b21f Binary files /dev/null and b/docs/images/high_level_architecture.png differ diff --git a/docs/images/tss_schema.png b/docs/images/tss_schema.png new file mode 100644 index 000000000..22d1fe823 Binary files /dev/null and b/docs/images/tss_schema.png differ diff --git a/go.list b/go.list new file mode 100644 index 000000000..425f953c1 --- /dev/null +++ b/go.list @@ -0,0 +1,292 @@ +github.com/stellar/stellar-disbursement-platform-backend +cloud.google.com/go v0.110.0 +cloud.google.com/go/bigquery v1.8.0 +cloud.google.com/go/compute v1.19.0 +cloud.google.com/go/compute/metadata v0.2.3 +cloud.google.com/go/datastore v1.1.0 +cloud.google.com/go/firestore v1.9.0 +cloud.google.com/go/longrunning v0.4.1 +cloud.google.com/go/pubsub v1.3.1 +cloud.google.com/go/storage v1.14.0 +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9 +firebase.google.com/go v3.12.0+incompatible +github.com/BurntSushi/toml v1.3.2 +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802 +github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53 +github.com/CloudyKit/jet/v6 v6.2.0 +github.com/Joker/jade v1.1.3 +github.com/Masterminds/goutils v1.1.1 +github.com/Masterminds/semver/v3 v3.2.0 +github.com/Masterminds/sprig/v3 v3.2.3 +github.com/Masterminds/squirrel v1.5.0 +github.com/Microsoft/go-winio v0.4.14 +github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 +github.com/adjust/goautoneg v0.0.0-20150426214442-d788f35a0315 +github.com/ajg/form v0.0.0-20160822230020-523a5da1a92f +github.com/alecthomas/kingpin/v2 v2.3.2 +github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 +github.com/andybalholm/brotli v1.0.5 +github.com/armon/go-metrics v0.4.0 +github.com/armon/go-radix v1.0.0 +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 +github.com/aws/aws-sdk-go v1.44.321 +github.com/aymerick/douceur v0.2.0 +github.com/beevik/etree v1.1.0 +github.com/beorn7/perks v1.0.1 +github.com/bgentry/speakeasy v0.1.0 +github.com/buger/goreplay v1.3.2 +github.com/census-instrumentation/opencensus-proto v0.2.1 +github.com/cespare/xxhash/v2 v2.2.0 +github.com/chzyer/logex v1.2.1 +github.com/chzyer/readline v1.5.1 +github.com/chzyer/test v1.0.0 +github.com/client9/misspell v0.3.4 +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 +github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 +github.com/coreos/go-semver v0.3.0 +github.com/coreos/go-systemd/v22 v22.3.2 +github.com/cpuguy83/go-md2man/v2 v2.0.2 +github.com/creack/pty v1.1.9 +github.com/davecgh/go-spew v1.1.1 +github.com/denisenkom/go-mssqldb v0.9.0 +github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385 +github.com/elazarl/go-bindata-assetfs v1.0.0 +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad +github.com/envoyproxy/protoc-gen-validate v0.1.0 +github.com/fatih/color v1.13.0 +github.com/fatih/structs v1.1.0 +github.com/flosch/pongo2/v4 v4.0.2 +github.com/frankban/quicktest v1.14.4 +github.com/fsnotify/fsnotify v1.6.0 +github.com/gavv/monotime v0.0.0-20161010190848-47d58efa6955 +github.com/getsentry/raven-go v0.0.0-20160805001729-c9d3cc542ad1 +github.com/getsentry/sentry-go v0.23.0 +github.com/gin-contrib/sse v0.1.0 +github.com/gin-gonic/gin v1.8.1 +github.com/go-chi/chi v4.1.2+incompatible +github.com/go-chi/chi/v5 v5.0.10 +github.com/go-errors/errors v1.4.2 +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1 +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4 +github.com/go-gorp/gorp/v3 v3.1.0 +github.com/go-kit/log v0.2.1 +github.com/go-logfmt/logfmt v0.5.1 +github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab +github.com/go-playground/locales v0.14.0 +github.com/go-playground/universal-translator v0.18.0 +github.com/go-playground/validator/v10 v10.11.1 +github.com/go-sql-driver/mysql v1.6.0 +github.com/gobuffalo/logger v1.0.6 +github.com/gobuffalo/packd v1.0.1 +github.com/gobuffalo/packr v1.12.1 +github.com/gobuffalo/packr/v2 v2.8.3 +github.com/gocarina/gocsv v0.0.0-20230616125104-99d496ca653d +github.com/goccy/go-json v0.9.11 +github.com/godror/godror v0.24.2 +github.com/gogo/protobuf v1.3.2 +github.com/golang-jwt/jwt v3.2.2+incompatible +github.com/golang-jwt/jwt/v4 v4.5.0 +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da +github.com/golang/mock v1.6.0 +github.com/golang/protobuf v1.5.3 +github.com/golang/snappy v0.0.4 +github.com/google/btree v1.0.0 +github.com/google/go-cmp v0.5.9 +github.com/google/go-querystring v0.0.0-20160401233042-9235644dd9e5 +github.com/google/martian v2.1.0+incompatible +github.com/google/martian/v3 v3.1.0 +github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2 +github.com/google/renameio v0.1.0 +github.com/google/s2a-go v0.1.3 +github.com/google/uuid v1.3.0 +github.com/googleapis/enterprise-certificate-proxy v0.2.3 +github.com/googleapis/gax-go/v2 v2.8.0 +github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8 +github.com/gorilla/css v1.0.0 +github.com/gorilla/schema v1.2.0 +github.com/graph-gophers/graphql-go v1.3.0 +github.com/guregu/null v2.1.3-0.20151024101046-79c5bd36b615+incompatible +github.com/hashicorp/consul/api v1.20.0 +github.com/hashicorp/errwrap v1.1.0 +github.com/hashicorp/go-cleanhttp v0.5.2 +github.com/hashicorp/go-hclog v1.2.0 +github.com/hashicorp/go-immutable-radix v1.3.1 +github.com/hashicorp/go-multierror v1.1.1 +github.com/hashicorp/go-rootcerts v1.0.2 +github.com/hashicorp/golang-lru v0.5.4 +github.com/hashicorp/hcl v1.0.0 +github.com/hashicorp/serf v0.10.1 +github.com/holiman/uint256 v1.2.0 +github.com/howeyc/gopass v0.0.0-20170109162249-bf9dde6d0d2c +github.com/hpcloud/tail v1.0.0 +github.com/huandu/xstrings v1.4.0 +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639 +github.com/imdario/mergo v0.3.13 +github.com/imkira/go-interpol v1.1.0 +github.com/inconshreveable/mousetrap v1.1.0 +github.com/iris-contrib/schema v0.0.6 +github.com/jarcoal/httpmock v0.0.0-20161210151336-4442edb3db31 +github.com/jmespath/go-jmespath v0.4.0 +github.com/jmespath/go-jmespath/internal/testify v1.5.1 +github.com/jmoiron/sqlx v1.3.5 +github.com/josharian/intern v1.0.0 +github.com/jpillora/backoff v1.0.0 +github.com/json-iterator/go v1.1.12 +github.com/jstemmer/go-junit-report v0.9.1 +github.com/julienschmidt/httprouter v1.3.0 +github.com/karrick/godirwalk v1.16.1 +github.com/kataras/blocks v0.0.7 +github.com/kataras/golog v0.1.8 +github.com/kataras/iris/v12 v12.2.0 +github.com/kataras/pio v0.0.11 +github.com/kataras/sitemap v0.0.6 +github.com/kataras/tunnel v0.0.4 +github.com/kisielk/gotool v1.0.0 +github.com/klauspost/compress v1.16.0 +github.com/konsorten/go-windows-terminal-sequences v1.0.1 +github.com/kr/fs v0.1.0 +github.com/kr/pretty v0.3.1 +github.com/kr/pty v1.1.1 +github.com/kr/text v0.2.0 +github.com/labstack/echo/v4 v4.10.0 +github.com/labstack/gommon v0.4.0 +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 +github.com/leodido/go-urn v1.2.1 +github.com/lib/pq v1.10.9 +github.com/localtunnel/go-localtunnel v0.0.0-20170326223115-8a804488f275 +github.com/magiconair/properties v1.8.7 +github.com/mailgun/raymond/v2 v2.0.48 +github.com/mailru/easyjson v0.7.7 +github.com/manifoldco/promptui v0.9.0 +github.com/manucorporat/sse v0.0.0-20160126180136-ee05b128a739 +github.com/markbates/errx v1.1.0 +github.com/markbates/oncer v1.0.0 +github.com/markbates/safe v1.0.1 +github.com/mattn/go-colorable v0.1.13 +github.com/mattn/go-isatty v0.0.17 +github.com/mattn/go-oci8 v0.1.1 +github.com/mattn/go-runewidth v0.0.9 +github.com/mattn/go-sqlite3 v1.14.15 +github.com/matttproud/golang_protobuf_extensions v1.0.4 +github.com/microcosm-cc/bluemonday v1.0.23 +github.com/mitchellh/cli v1.1.5 +github.com/mitchellh/copystructure v1.2.0 +github.com/mitchellh/go-homedir v1.1.0 +github.com/mitchellh/mapstructure v1.5.0 +github.com/mitchellh/reflectwalk v1.0.2 +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd +github.com/modern-go/reflect2 v1.0.2 +github.com/moul/http2curl v0.0.0-20161031194548-4e24498b31db +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f +github.com/nelsam/hel/v2 v2.3.3 +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e +github.com/nyaruka/phonenumbers v1.1.8 +github.com/olekukonko/tablewriter v0.0.5 +github.com/onsi/ginkgo v1.7.0 +github.com/onsi/gomega v1.4.3 +github.com/opentracing/opentracing-go v1.1.0 +github.com/pelletier/go-toml v1.9.0 +github.com/pelletier/go-toml/v2 v2.0.9 +github.com/pingcap/errors v0.11.4 +github.com/pkg/errors v0.9.1 +github.com/pkg/sftp v1.13.1 +github.com/pmezard/go-difflib v1.0.0 +github.com/posener/complete v1.2.3 +github.com/poy/onpar v1.1.2 +github.com/prometheus/client_golang v1.16.0 +github.com/prometheus/client_model v0.4.0 +github.com/prometheus/common v0.44.0 +github.com/prometheus/procfs v0.11.1 +github.com/rogpeppe/go-internal v1.10.0 +github.com/rs/cors v1.9.0 +github.com/rs/xhandler v0.0.0-20160618193221-ed27b6fd6521 +github.com/rubenv/sql-migrate v1.5.2 +github.com/russross/blackfriday/v2 v2.1.0 +github.com/sagikazarmark/crypt v0.10.0 +github.com/schollz/closestmatch v2.1.0+incompatible +github.com/segmentio/go-loggly v0.5.1-0.20171222203950-eb91657e62b2 +github.com/sergi/go-diff v0.0.0-20161205080420-83532ca1c1ca +github.com/shopspring/decimal v1.3.1 +github.com/shurcooL/httpfs v0.0.0-20190707220628-8d4bc4ba7749 +github.com/sirupsen/logrus v1.9.3 +github.com/spf13/afero v1.9.5 +github.com/spf13/cast v1.5.1 +github.com/spf13/cobra v1.7.0 +github.com/spf13/jwalterweatherman v1.1.0 +github.com/spf13/pflag v1.0.5 +github.com/spf13/viper v1.16.0 +github.com/stellar/go v0.0.0-20230810175703-9c94bc588b15 +github.com/stellar/go-xdr v0.0.0-20211103144802-8017fc4bdfee +github.com/stellar/throttled v2.2.3-0.20190823235211-89d75816f59d+incompatible +github.com/stretchr/objx v0.5.1 +github.com/stretchr/testify v1.8.4 +github.com/subosito/gotenv v1.4.2 +github.com/tdewolff/minify/v2 v2.12.4 +github.com/tdewolff/parse/v2 v2.6.4 +github.com/twilio/twilio-go v1.11.0 +github.com/tyler-smith/go-bip39 v0.0.0-20180618194314-52158e4697b8 +github.com/ugorji/go/codec v1.2.7 +github.com/urfave/negroni v1.0.0 +github.com/valyala/bytebufferpool v1.0.0 +github.com/valyala/fasthttp v1.40.0 +github.com/valyala/fasttemplate v1.2.2 +github.com/vmihailenco/msgpack/v5 v5.3.5 +github.com/vmihailenco/tagparser/v2 v2.0.0 +github.com/xdrpp/goxdr v0.1.1 +github.com/xeipuuv/gojsonpointer v0.0.0-20151027082146-e0fe6f683076 +github.com/xeipuuv/gojsonreference v0.0.0-20150808065054-e02fc20de94c +github.com/xeipuuv/gojsonschema v0.0.0-20161231055540-f06f290571ce +github.com/xhit/go-str2duration/v2 v2.1.0 +github.com/yalp/jsonpath v0.0.0-20150812003900-31a79c7593bb +github.com/yosssi/ace v0.0.5 +github.com/yudai/gojsondiff v0.0.0-20170107030110-7b1b7adf999d +github.com/yudai/golcs v0.0.0-20150405163532-d1c525dea8ce +github.com/yudai/pp v2.0.1+incompatible +github.com/yuin/goldmark v1.4.13 +github.com/ziutek/mymysql v1.5.4 +go.etcd.io/etcd/api/v3 v3.5.9 +go.etcd.io/etcd/client/pkg/v3 v3.5.9 +go.etcd.io/etcd/client/v2 v2.305.7 +go.etcd.io/etcd/client/v3 v3.5.9 +go.opencensus.io v0.24.0 +go.uber.org/atomic v1.9.0 +go.uber.org/multierr v1.8.0 +go.uber.org/zap v1.21.0 +golang.org/x/crypto v0.12.0 +golang.org/x/exp v0.0.0-20230810033253-352e893a4cad +golang.org/x/image v0.0.0-20190802002840-cff245a6509b +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 +golang.org/x/mod v0.11.0 +golang.org/x/net v0.14.0 +golang.org/x/oauth2 v0.8.0 +golang.org/x/sync v0.3.0 +golang.org/x/sys v0.11.0 +golang.org/x/term v0.11.0 +golang.org/x/text v0.12.0 +golang.org/x/time v0.3.0 +golang.org/x/tools v0.6.0 +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 +google.golang.org/api v0.122.0 +google.golang.org/appengine v1.6.7 +google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 +google.golang.org/grpc v1.55.0 +google.golang.org/protobuf v1.31.0 +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c +gopkg.in/errgo.v2 v2.1.0 +gopkg.in/fsnotify.v1 v1.4.7 +gopkg.in/gavv/httpexpect.v1 v1.0.0-20170111145843-40724cf1e4a0 +gopkg.in/gorp.v1 v1.7.1 +gopkg.in/ini.v1 v1.67.0 +gopkg.in/square/go-jose.v2 v2.4.1 +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 +gopkg.in/tylerb/graceful.v1 v1.2.15 +gopkg.in/yaml.v2 v2.4.0 +gopkg.in/yaml.v3 v3.0.1 +honnef.co/go/tools v0.0.1-2020.1.4 +rsc.io/binaryregexp v0.2.0 +rsc.io/quote/v3 v3.1.0 +rsc.io/sampler v1.3.0 diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..39b9ca477 --- /dev/null +++ b/go.mod @@ -0,0 +1,74 @@ +module github.com/stellar/stellar-disbursement-platform-backend + +go 1.19 + +require ( + github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 + github.com/aws/aws-sdk-go v1.44.321 + github.com/getsentry/sentry-go v0.23.0 + github.com/go-chi/chi v4.1.2+incompatible + github.com/go-chi/chi/v5 v5.0.10 + github.com/gocarina/gocsv v0.0.0-20230616125104-99d496ca653d + github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/google/uuid v1.3.0 + github.com/jmoiron/sqlx v1.3.5 + github.com/lib/pq v1.10.9 + github.com/manifoldco/promptui v0.9.0 + github.com/nyaruka/phonenumbers v1.1.8 + github.com/prometheus/client_golang v1.16.0 + github.com/rs/cors v1.9.0 + github.com/rubenv/sql-migrate v1.5.2 + github.com/sirupsen/logrus v1.9.3 + github.com/spf13/cobra v1.7.0 + github.com/spf13/viper v1.16.0 + github.com/stellar/go v0.0.0-20230810175703-9c94bc588b15 + github.com/stretchr/testify v1.8.4 + github.com/twilio/twilio-go v1.11.0 + golang.org/x/crypto v0.12.0 + golang.org/x/exp v0.0.0-20230810033253-352e893a4cad +) + +require ( + github.com/BurntSushi/toml v1.3.2 // indirect + github.com/Masterminds/squirrel v1.5.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chzyer/readline v1.5.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/go-errors/errors v1.4.2 // indirect + github.com/go-gorp/gorp/v3 v3.1.0 // indirect + github.com/golang/mock v1.6.0 // indirect + github.com/golang/protobuf v1.5.3 // indirect + github.com/gorilla/schema v1.2.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect + github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/manucorporat/sse v0.0.0-20160126180136-ee05b128a739 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.0.9 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.4.0 // indirect + github.com/prometheus/common v0.44.0 // indirect + github.com/prometheus/procfs v0.11.1 // indirect + github.com/segmentio/go-loggly v0.5.1-0.20171222203950-eb91657e62b2 // indirect + github.com/spf13/afero v1.9.5 // indirect + github.com/spf13/cast v1.5.1 // indirect + github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/stellar/go-xdr v0.0.0-20211103144802-8017fc4bdfee // indirect + github.com/stretchr/objx v0.5.1 // indirect + github.com/subosito/gotenv v1.4.2 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/sys v0.11.0 // indirect + golang.org/x/text v0.12.0 // indirect + google.golang.org/protobuf v1.31.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/tylerb/graceful.v1 v1.2.15 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..320839c2b --- /dev/null +++ b/go.sum @@ -0,0 +1,661 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= +cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= +cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= +cloud.google.com/go v0.44.3/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= +cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= +cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= +cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= +cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= +cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= +cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= +cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= +cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= +cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= +cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= +cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI= +cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmWk= +cloud.google.com/go v0.75.0/go.mod h1:VGuuCn7PG0dwsd5XPVm2Mm3wlh3EL55/79EKB6hlPTY= +cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= +cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= +cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= +cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= +cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= +cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= +cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= +cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= +cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= +cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= +cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= +cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= +cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= +cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= +cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= +cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= +cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= +cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= +github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Masterminds/squirrel v1.5.0 h1:JukIZisrUXadA9pl3rMkjhiamxiB0cXiu+HGp/Y8cY8= +github.com/Masterminds/squirrel v1.5.0/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= +github.com/ajg/form v0.0.0-20160822230020-523a5da1a92f h1:zvClvFQwU++UpIUBGC8YmDlfhUrweEy1R1Fj1gu5iIM= +github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= +github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-sdk-go v1.44.321 h1:iXwFLxWjZPjYqjPq0EcCs46xX7oDLEELte1+BzgpKk8= +github.com/aws/aws-sdk-go v1.44.321/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= +github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= +github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gavv/monotime v0.0.0-20161010190848-47d58efa6955 h1:gmtGRvSexPU4B1T/yYo0sLOKzER1YT+b4kPxPpm0Ty4= +github.com/getsentry/sentry-go v0.23.0 h1:dn+QRCeJv4pPt9OjVXiMcGIBIefaTJPw/h0bZWO05nE= +github.com/getsentry/sentry-go v0.23.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= +github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= +github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= +github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= +github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= +github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/gobuffalo/logger v1.0.6 h1:nnZNpxYo0zx+Aj9RfMPBm+x9zAU2OayFh/xrAWi34HU= +github.com/gobuffalo/packd v1.0.1 h1:U2wXfRr4E9DH8IdsDLlRFwTZTK7hLfq9qT/QHXGVe/0= +github.com/gobuffalo/packr v1.12.1 h1:+5u3rqgdhswdYXhrX6DHaO7BM4P8oxrbvgZm9H1cRI4= +github.com/gobuffalo/packr/v2 v2.8.3 h1:xE1yzvnO56cUC0sTpKR3DIbxZgB54AftTFMhB2XEWlY= +github.com/gocarina/gocsv v0.0.0-20230616125104-99d496ca653d h1:KbPOUXFUDJxwZ04vbmDOc3yuruGvVO+LOa7cVER3yWw= +github.com/gocarina/gocsv v0.0.0-20230616125104-99d496ca653d/go.mod h1:5YoVOkjYAQumqlV356Hj3xeYh4BdZuLE0/nRkf2NKkI= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-querystring v0.0.0-20160401233042-9235644dd9e5 h1:oERTZ1buOUYlpmKaqlO5fYmz8cZ1rYu5DieJzF4ZVmU= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= +github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= +github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= +github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jarcoal/httpmock v0.0.0-20161210151336-4442edb3db31 h1:Aw95BEvxJ3K6o9GGv5ppCd1P8hkeIeEJ30FO+OhOJpM= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= +github.com/karrick/godirwalk v1.16.1 h1:DynhcF+bztK8gooS0+NDJFrdNZjJ3gzVzC545UNA9iw= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.16.0 h1:iULayQNOReoYUe+1qtKOqw9CwJv3aNQu8ivo7lw1HU4= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/localtunnel/go-localtunnel v0.0.0-20170326223115-8a804488f275 h1:IZycmTpoUtQK3PD60UYBwjaCUHUP7cML494ao9/O8+Q= +github.com/localtunnel/go-localtunnel v0.0.0-20170326223115-8a804488f275/go.mod h1:zt6UU74K6Z6oMOYJbJzYpYucqdcQwSMPBEdSvGiaUMw= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= +github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= +github.com/manucorporat/sse v0.0.0-20160126180136-ee05b128a739 h1:ykXz+pRRTibcSjG1yRhpdSHInF8yZY/mfn+Rz2Nd1rE= +github.com/manucorporat/sse v0.0.0-20160126180136-ee05b128a739/go.mod h1:zUx1mhth20V3VKgL5jbd1BSQcW4Fy6Qs4PZvQwRFwzM= +github.com/markbates/errx v1.1.0 h1:QDFeR+UP95dO12JgW+tgi2UVfo0V8YBHiUIOaeBPiEI= +github.com/markbates/oncer v1.0.0 h1:E83IaVAHygyndzPimgUYJjbshhDTALZyXxvk9FOlQRY= +github.com/markbates/safe v1.0.1 h1:yjZkbvRM6IzKj9tlu/zMJLS0n/V351OZWRnF3QfaUxI= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= +github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/moul/http2curl v0.0.0-20161031194548-4e24498b31db h1:eZgFHVkk9uOTaOQLC6tgjkzdp7Ays8eEVecBcfHZlJQ= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nyaruka/phonenumbers v1.1.8 h1:mjFu85FeoH2Wy18aOMUvxqi1GgAqiQSJsa/cCC5yu2s= +github.com/nyaruka/phonenumbers v1.1.8/go.mod h1:DC7jZd321FqUe+qWSNcHi10tyIyGNXGcNbfkPvdp1Vs= +github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= +github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY= +github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8= +github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY= +github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= +github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY= +github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= +github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI= +github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= +github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/rubenv/sql-migrate v1.5.2 h1:bMDqOnrJVV/6JQgQ/MxOpU+AdO8uzYYA/TxFUBzFtS0= +github.com/rubenv/sql-migrate v1.5.2/go.mod h1:H38GW8Vqf8F0Su5XignRyaRcbXbJunSWxs+kmzlg0Is= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/segmentio/go-loggly v0.5.1-0.20171222203950-eb91657e62b2 h1:S4OC0+OBKz6mJnzuHioeEat74PuQ4Sgvbf8eus695sc= +github.com/segmentio/go-loggly v0.5.1-0.20171222203950-eb91657e62b2/go.mod h1:8zLRYR5npGjaOXgPSKat5+oOh+UHd8OdbS18iqX9F6Y= +github.com/sergi/go-diff v0.0.0-20161205080420-83532ca1c1ca h1:oR/RycYTFTVXzND5r4FdsvbnBn0HJXSVeNAnwaTXRwk= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM= +github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ= +github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA= +github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmqnqMYADFk= +github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.16.0 h1:rGGH0XDZhdUOryiDWjmIvUSWpbNqisK8Wk0Vyefw8hc= +github.com/spf13/viper v1.16.0/go.mod h1:yg78JgCJcbrQOvV9YLXgkLaZqUidkY9K+Dd1FofRzQg= +github.com/stellar/go v0.0.0-20230810175703-9c94bc588b15 h1:snRtfXX7WGO3frwMk6KtAJzLCRX9t48xDx0PX6tUbXg= +github.com/stellar/go v0.0.0-20230810175703-9c94bc588b15/go.mod h1:iTkyf5zUHlaIjZjyxaLLXLv+YHqg3etsqn8AOQ+DvG8= +github.com/stellar/go-xdr v0.0.0-20211103144802-8017fc4bdfee h1:fbVs0xmXpBvVS4GBeiRmAE3Le70ofAqFMch1GTiq/e8= +github.com/stellar/go-xdr v0.0.0-20211103144802-8017fc4bdfee/go.mod h1:yoxyU/M8nl9LKeWIoBrbDPQ7Cy+4jxRcWcOayZ4BMps= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.1 h1:4VhoImhV/Bm0ToFkXFi8hXNXwpDRZ/ynw3amt82mzq0= +github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdrifcy0= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= +github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= +github.com/twilio/twilio-go v1.11.0 h1:ixO2DfAV4c0Yza0Tom5F5ZZB8WUbigiFc9wD84vbYnc= +github.com/twilio/twilio-go v1.11.0/go.mod h1:tdnfQ5TjbewoAu4lf9bMsGvfuJ/QU9gYuv9yx3TSIXU= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/fasthttp v1.40.0 h1:CRq/00MfruPGFLTQKY8b+8SfdK60TxNztjRMnH0t1Yc= +github.com/xdrpp/goxdr v0.1.1 h1:E1B2c6E8eYhOVyd7yEpOyopzTPirUeF6mVOfXfGyJyc= +github.com/xeipuuv/gojsonpointer v0.0.0-20151027082146-e0fe6f683076 h1:KM4T3G70MiR+JtqplcYkNVoNz7pDwYaBxWBXQK804So= +github.com/xeipuuv/gojsonreference v0.0.0-20150808065054-e02fc20de94c h1:XZWnr3bsDQWAZg4Ne+cPoXRPILrNlPNQfxBuwLl43is= +github.com/xeipuuv/gojsonschema v0.0.0-20161231055540-f06f290571ce h1:cVSRGH8cOveJNwFEEZLXtB+XMnRqKLjUP6V/ZFYQCXI= +github.com/yalp/jsonpath v0.0.0-20150812003900-31a79c7593bb h1:06WAhQa+mYv7BiOk13B/ywyTlkoE/S7uu6TBKU6FHnE= +github.com/yudai/gojsondiff v0.0.0-20170107030110-7b1b7adf999d h1:yJIizrfO599ot2kQ6Af1enICnwBD3XoxgX3MrMwot2M= +github.com/yudai/golcs v0.0.0-20150405163532-d1c525dea8ce h1:888GrqRxabUce7lj4OaoShPxodm3kXOMpSa85wdYzfY= +github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= +go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= +golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= +golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= +golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= +golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230810033253-352e893a4cad h1:g0bG7Z4uG+OgH2QDODnjp6ggkk1bJDsINcuWmJN1iJU= +golang.org/x/exp v0.0.0-20230810033253-352e893a4cad/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= +golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= +golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= +golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE= +golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= +google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= +google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= +google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= +google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= +google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= +google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= +google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE= +google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= +google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= +google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= +google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= +google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/gavv/httpexpect.v1 v1.0.0-20170111145843-40724cf1e4a0 h1:r5ptJ1tBxVAeqw4CrYWhXIMr0SybY3CDHuIbCg5CFVw= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tylerb/graceful.v1 v1.2.15 h1:1JmOyhKqAyX3BgTXMI84LwT6FOJ4tP2N9e2kwTCM0nQ= +gopkg.in/tylerb/graceful.v1 v1.2.15/go.mod h1:yBhekWvR20ACXVObSSdD3u6S9DeSylanL2PAbAC/uJ8= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/gomod.sh b/gomod.sh new file mode 100755 index 000000000..64210da0a --- /dev/null +++ b/gomod.sh @@ -0,0 +1,8 @@ +#! /bin/bash +set -e + +go mod tidy +git diff --exit-code -- go.mod || (echo "Go file go.mod is dirty, update the file with 'go mod tidy' locally." && exit 1) +git diff --exit-code -- go.sum || (echo "Go file go.sum is dirty, update the file with 'go mod tidy' locally." && exit 1) +diff -u go.list <(go list -m all) || (echo "Go dependencies have changed, update the go.list file with 'go list -m all > go.list' locally." && exit 1) +go mod verify || (echo "One or more Go dependencies failed verification. Either a version is no longer available, or the author or someone else has modified the version so it no longer points to the same code." && exit 1) diff --git a/helmchart/docs/README.md b/helmchart/docs/README.md new file mode 100644 index 000000000..2731f3ba3 --- /dev/null +++ b/helmchart/docs/README.md @@ -0,0 +1,138 @@ +# Stellar Disbursement Platform - Helm Chart + +## Table of Contents + +- [Stellar Disbursement Platform - Helm Chart](#stellar-disbursement-platform---helm-chart) + - [Table of Contents](#table-of-contents) + - [Installation](#installation) + - [Local Development Cheatsheet](#local-development-cheatsheet) + - [Using Ingress and a Local TLS Certificate](#using-ingress-and-a-local-tls-certificate) + - [Creating local secrets for the deployments](#creating-local-secrets-for-the-deployments) + +## Installation + +```bash +helm install {release-name-here} ./sdp +``` + +Likewise, to uninstall it you can run: + +```bash +helm uninstall {release-name-here} +``` + +And if you want to upgrade a version that's currently deployed, you can do the following: + +```bash +helm upgrade {release-name-here} ./sdp +``` + +## Local Development Cheatsheet + +For debugging purposes, it's sometimes useful to render the templates locally. To do so, you can execute: + +```bash +helm template --release-name {release-name-here} -f values.yaml --debug . +``` + +If you want to deploy this locally, you can enable kubernetes on docker-desktop. Some useful commands: + +### Using Ingress and a Local TLS Certificate + +To create a self-signed TLS certificate for local development purposes with both `sdp.localhost.com` and `ap.localhost.com` as endpoints, follow these steps (you only need to do it once): + +1. Install `ingress-nginx`: + +```bash +helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx +helm repo update +kubectl create namespace ingress-nginx +helm install ingress-nginx ingress-nginx/ingress-nginx --namespace=ingress-nginx +``` + + +2. Create a `openssl.cnf` configuration file: + +Create a new file named `openssl.cnf` with the following content, which includes both endpoints as subject alternative names (SANs): + +```bash +[req] +distinguished_name = req_distinguished_name +x509_extensions = v3_req +prompt = no + +[req_distinguished_name] +CN = localhost + +[v3_req] +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +subjectAltName = @alt_names + +[alt_names] +DNS.1 = sdp.localhost.com +DNS.2 = ap.localhost.com +``` + +3. Generate the self-signed certificate and key: +Run the following command to generate a self-signed certificate (`tls.crt`) and private key (`tls.key`) using the configuration file: + +```bash +openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout tls.key -out tls.crt -config openssl.cnf +``` + +This command will create a certificate valid for 365 days with a 2048-bit RSA key. + +4. Create a Kubernetes Secret with the TLS certificate and key: + +```bash +kubectl create secret tls stellar-disbursement-platform-backend-tls-cert --key tls.key --cert tls.crt --namespace=stellar-disbursement-platform +``` + +Replace `myapp-tls` with a descriptive name for your secret. + +5. Update your Ingress configuration to use the TLS secret: + +Add the `tls` section to your Ingress manifest, referencing the secret you created in step 3: + +```yaml +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: stellar-disbursement-platform-backend + namespace: stellar-disbursement-platform + # ... other metadata ... +spec: + ingressClassName: nginx # <---- This is important! + tls: + - hosts: + - sdp.localhost.com + - ap.localhost.com + secretName: myapp-tls + rules: + # ... existing rules ... +``` + +Reapply the Ingress manifest: + +```bash +kubectl apply -f .yaml +``` + +6. Update your `/etc/hosts` file by adding: + +```bash +# SDP + Anchor Platform: +127.0.0.1 sdp.localhost.com +127.0.0.1 ap.localhost.com +``` + +πŸŽ‰ Now, you should be able to access your services at `https://sdp.localhost.com` and `https://ap.localhost.com`. Keep in mind that browsers will display a security warning when accessing your site due to the use of a self-signed certificate. You can add an exception for your local domains to trust the self-signed certificate. + +### Creating local secrets for the deployments + +To create the secrets containing the env vars required for the deployments, simply create a .env file with the desired values, then run: + +```bash +kubectl create secret generic --from-env-file= --namespace= +``` diff --git a/helmchart/docs/openssl.cnf b/helmchart/docs/openssl.cnf new file mode 100644 index 000000000..188ae58d3 --- /dev/null +++ b/helmchart/docs/openssl.cnf @@ -0,0 +1,16 @@ +[req] +distinguished_name = req_distinguished_name +x509_extensions = v3_req +prompt = no + +[req_distinguished_name] +CN = localhost + +[v3_req] +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +subjectAltName = @alt_names + +[alt_names] +DNS.1 = sdp.localhost.com +DNS.2 = ap.localhost.com \ No newline at end of file diff --git a/helmchart/docs/tls.crt b/helmchart/docs/tls.crt new file mode 100644 index 000000000..bf29bc0b3 --- /dev/null +++ b/helmchart/docs/tls.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIC/zCCAeegAwIBAgIJAO/mofddJS20MA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV +BAMMCWxvY2FsaG9zdDAeFw0yMzA0MjEyMTE1MDJaFw0yODA0MTkyMTE1MDJaMBQx +EjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBALyckqZOspDYpyv0e54qesd1LTaCbuazrKcnjJNm9pgxe8hUJK3BjO1Z5miR +OFrg+Vy2sUivPZovLf8tJwt+MVU0d2GBxYmEExvQN8bjOT4GdRTFQW8msxmEazaM +Og9BaDIDDQoAFJgn9jQbRT+QUa+dnA+LJSOA6dujg0h+X+mXLBHTjzp7o2sW9+gP +HwQmHUngRdEZa6ABfwT4WEP2HR/IILWr68T8YIEtsKWMyBDuxSpmF5Xwd+oEqUjw +wG14rlInmbfK1Q2mK6iy00E9QsXDKRa61w9ysrLACu07goRUPz1Q/OpuTRmZHlK9 +cQMo65uQxjK+4czt1GYJKTUxygcCAwEAAaNUMFIwCwYDVR0PBAQDAgXgMBMGA1Ud +JQQMMAoGCCsGAQUFBwMBMC4GA1UdEQQnMCWCEXNkcC5sb2NhbGhvc3QuY29tghBh +cC5sb2NhbGhvc3QuY29tMA0GCSqGSIb3DQEBCwUAA4IBAQBaMZ0m/1B56Z29/Y9E +XieLWBP1iA4gTz82OeYfh46fRg1zIb/qC//A98U7BAaLMOgDA5ZfaUmGvADJkcVL +/cY7tdrWlKdUhiE0nspsBMaJqwOSN/J+G1OCYzBkDdndCvi8+NisRJpKZ5c+rnE5 +e/7zNmANBy0dy6q3Oyncq8BFKFkxYV7ZDejZln0FLsqm96DJMhhZv8976vA3YEtz +phkSq+c8nQUXadqYwZoxHIpPlxS/4SFBXhghNfQ3IinhtIQ9arY8hnUWnCbGPWig +vgwofeZmPbX1p4l20N6dR7VnH5AsWCL0P5fkxfZH5EOotyeTqB5SGuG+WTPEDSNL +nmhE +-----END CERTIFICATE----- diff --git a/helmchart/docs/tls.key b/helmchart/docs/tls.key new file mode 100644 index 000000000..56051e38f --- /dev/null +++ b/helmchart/docs/tls.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC8nJKmTrKQ2Kcr +9HueKnrHdS02gm7ms6ynJ4yTZvaYMXvIVCStwYztWeZokTha4PlctrFIrz2aLy3/ +LScLfjFVNHdhgcWJhBMb0DfG4zk+BnUUxUFvJrMZhGs2jDoPQWgyAw0KABSYJ/Y0 +G0U/kFGvnZwPiyUjgOnbo4NIfl/plywR0486e6NrFvfoDx8EJh1J4EXRGWugAX8E ++FhD9h0fyCC1q+vE/GCBLbCljMgQ7sUqZheV8HfqBKlI8MBteK5SJ5m3ytUNpiuo +stNBPULFwykWutcPcrKywArtO4KEVD89UPzqbk0ZmR5SvXEDKOubkMYyvuHM7dRm +CSk1McoHAgMBAAECggEBAKbKgzkEkQ5cOQE603KcODVYXuI8WBw1ubGb3MmhHOKb +p84Q8tqhNaHThVjlSwO6vWmKuZ4AMia+IBvvbv7P2opxujyFCQ3BuTA4YorD96Pt +C+6RPXswquMe4by8Jr/E5IKNhiNYzN19QVD06Lj8Q/BcHVz1fKM35hZwM7GQ5/pl +Su5dtJ6tGqylfqb8kEO3j0pHUZH83zJ5JpKofe1i8ZvMfPXGShO339m4ecd/Rf5R +9yJ4qVdCGhXtBdz/O+7bxccVNJIVkxCNAq9umfCLIvUT+EpDJmV9rycdWwRjiHjE +eluflft9KEsga9Lqb/COEdRc6EifIhbvCgDbKvuVvTkCgYEA5fT1R/YA0jgg06B5 +qEEV2bnYLBJBE00qJawxahaZjvr74p7L5dPqDlivmhwDXF4Q+QBWRzwl8EL5p6uW +4F9MQgZRgEwsP1/a3U9Dw5mbqZ6fn/rkW2wLwgVYIYjITGHPsyEy1LOj19JYCyUQ +9DmVBJCn7wWO/GILr73Iw3Bcc4UCgYEA0fjoxzcLNcxGDEP3FnlilCs0tfNGimQp +/HU5Q1wCCZ+04GK9YAxsV2PT3kVUjH+UByZiFL0Or6MWz1+DlBiwFHOdDEqHiPPL +JD4RYQzfo+vOmYJETQ0sDkYMrqhhSppYCxLUIXEtU0cZ0b1/UCxr/cEJf6R9bbFg +kZhd0SKKnxsCgYBZEVAP11DqG6NbVMTKTqtP8ZOxPkDGYRT6En/xP1+q6bu2Qxtm +oXX+qIsbfc4vcJ/SUjcY0EtBjC92qhd+QGshB6F5uAdLZK05GwJ8OHr6b94T8PGS +F39WXwuLsZcjPp9cGne9uvazGV3Qs0Kl1cfKRN1GzzhauP8dyryANn0YoQKBgBEo +IIUep0jXDyYza34nnvlyalUvsqTeOFwLjAlH/Fai+RmYl9bATR365zXzPkxYpFTN +OxhstkV9swBw0oSIW+Lf64Y0lMyI9yFX/P2MGr3/J5t9fG07VU05RhIDaie5YtZM +zI6K++QhHCf6LuvzJUPPwSHv49vRsY1UAN50zxTfAoGAdA02YPTngC3SIVsbuBbq +YRlVhNF8xNEYGlWZOjC/jqFa7PFdWXMEb5u+BpVe5tH+Nv7aVZXCig1pIUXjd8gZ +iwjv/Ti+gj8sG2T7EfW9FggJ7D0aYnAubteJZLERs5km7Hu3KGRnJvXKrbNjfGs7 +7OGE2Djmmpib1gFem28X8nk= +-----END PRIVATE KEY----- diff --git a/helmchart/sdp/.helmignore b/helmchart/sdp/.helmignore new file mode 100644 index 000000000..0e8a0eb36 --- /dev/null +++ b/helmchart/sdp/.helmignore @@ -0,0 +1,23 @@ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. +.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/helmchart/sdp/Chart.yaml b/helmchart/sdp/Chart.yaml new file mode 100644 index 000000000..0bf3cc3a7 --- /dev/null +++ b/helmchart/sdp/Chart.yaml @@ -0,0 +1,15 @@ +apiVersion: v2 +name: sdp +description: A Helm chart for the Stellar Disbursement Platform Backend (A.K.A. `sdp`) + +# A chart can be either an 'application' or a 'library' chart. +type: application + +# This is the chart version. This version number should be incremented each time you make changes +# to the chart and its templates, including the app version. +# Versions are expected to follow Semantic Versioning (https://semver.org/) +version: 0.2.0 + +# This is the version number of the application being deployed. Should be the +# same as the one used in ./main.go#Version. +appVersion: "0.2.0" diff --git a/helmchart/sdp/templates/01.1-configmap-sdp.yaml b/helmchart/sdp/templates/01.1-configmap-sdp.yaml new file mode 100644 index 000000000..bc467272e --- /dev/null +++ b/helmchart/sdp/templates/01.1-configmap-sdp.yaml @@ -0,0 +1,30 @@ +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "sdp.fullname" . }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} + + {{- if .Values.configMap.annotations }} + annotations: + {{- toYaml .Values.configMap.annotations | nindent 4 }} + {{- end }} + +{{- if .Values.configMap.data }} +data: + {{- if eq (include "isPubnet" .) "true" }} + NETWORK_PASSPHRASE: "Public Global Stellar Network ; September 2015" + HORIZON_URL: "https://horizon.stellar.org" + {{- else }} + NETWORK_PASSPHRASE: "Test SDF Network ; September 2015" + HORIZON_URL: "https://horizon-testnet.stellar.org" + {{- end }} + BASE_URL: {{ include "sdp.schema" . }}://{{ include "sdp.domain" . }} + PORT: {{ include "sdp.port" . | quote }} + METRICS_PORT: {{ include "sdp.metricsPort" . | quote }} + ANCHOR_PLATFORM_BASE_SEP_URL: {{ include "sdp.ap.schema" . }}://{{ include "sdp.ap.domain" . }} + ANCHOR_PLATFORM_BASE_PLATFORM_URL: {{ include "sdp.ap.platformServiceAddress" . }} + {{- tpl (toYaml .Values.configMap.data | nindent 2) . }} +{{- end }} diff --git a/helmchart/sdp/templates/01.2-configmap-ap.yaml b/helmchart/sdp/templates/01.2-configmap-ap.yaml new file mode 100644 index 000000000..ee19d8f37 --- /dev/null +++ b/helmchart/sdp/templates/01.2-configmap-ap.yaml @@ -0,0 +1,44 @@ +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "sdp.fullname" . }}-ap + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labelsWithSuffix" (list . "-ap") | nindent 4 }} + + {{- if .Values.anchorPlatform.configMap.annotations }} + annotations: + {{- toYaml .Values.anchorPlatform.configMap.annotations | nindent 4 }} + {{- end }} + +{{- if .Values.anchorPlatform.configMap.data }} +data: + # if {{ include "isPubnet" . }} is true, then the network is set to PUBNET, else it's all TESTNET + {{- if eq (include "isPubnet" .) "true" }} + STELLAR_NETWORK_NETWORK: "PUBNET" + STELLAR_NETWORK_NETWORK_PASSPHRASE: "Public Global Stellar Network ; September 2015" + STELLAR_NETWORK_HORIZON_URL: "https://horizon.stellar.org" + {{- else }} + STELLAR_NETWORK_NETWORK: "TESTNET" + STELLAR_NETWORK_NETWORK_PASSPHRASE: "Test SDF Network ; September 2015" + STELLAR_NETWORK_HORIZON_URL: "https://horizon-testnet.stellar.org" + {{- end }} + HOST_URL: {{ include "sdp.ap.schema" . }}://{{ include "sdp.ap.domain" . }} + SEP_SERVER_PORT: {{ include "sdp.ap.sepPort" . | quote }} + CALLBACK_API_BASE_URL: 'http://{{ include "sdp.fullname" . }}.{{ .Release.Namespace }}:{{ include "sdp.port" . }}' + DATA_TYPE: postgres + SEP1_ENABLED: "true" + SEP1_TOML_TYPE: url + SEP1_TOML_VALUE: 'http://{{ include "sdp.fullname" . }}.{{ .Release.Namespace }}:{{ include "sdp.port" . }}/.well-known/stellar.toml' + SEP10_ENABLED: "true" + SEP10_HOME_DOMAIN: {{ include "sdp.ap.domain" . }} + SEP24_ENABLED: "true" + SEP24_INTERACTIVE_URL_JWT_EXPIRATION: "1800" # 1800 seconds is 30 minutes + ASSETS_TYPE: json + SEP24_INTERACTIVE_URL_BASE_URL: {{ include "sdp.schema" . }}://{{ include "sdp.domain" . }}/wallet-registration/start + SEP24_MORE_INFO_URL_BASE_URL: {{ include "sdp.schema" . }}://{{ include "sdp.domain" . }}/wallet-registration/start + CALLBACK_API_AUTH_TYPE: none # TODO: update to jwt later + PLATFORM_SERVER_AUTH_TYPE: JWT + {{- tpl (toYaml .Values.anchorPlatform.configMap.data | nindent 2) . }} +{{- end }} diff --git a/helmchart/sdp/templates/01.3-configmap-tss.yaml b/helmchart/sdp/templates/01.3-configmap-tss.yaml new file mode 100644 index 000000000..5e8df9079 --- /dev/null +++ b/helmchart/sdp/templates/01.3-configmap-tss.yaml @@ -0,0 +1,27 @@ +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "sdp.fullname" . }}-tss + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} + + {{- if .Values.configMap.annotations }} + annotations: + {{- toYaml .Values.configMap.annotations | nindent 4 }} + {{- end }} + +data: + {{- if eq (include "isPubnet" .) "true" }} + NETWORK_PASSPHRASE: "Public Global Stellar Network ; September 2015" + HORIZON_URL: "https://horizon.stellar.org" + {{- else }} + NETWORK_PASSPHRASE: "Test SDF Network ; September 2015" + HORIZON_URL: "https://horizon-testnet.stellar.org" + {{- end }} + NUM_CHANNEL_ACCOUNTS: "{{ .Values.router.tss.numChannelAccounts }}" + MAX_BASE_FEE: "{{ .Values.router.tss.maxBaseFee }}" + MOCK: "{{ .Values.router.tss.mock }}" + TSS_METRICS_PORT: {{ include "tss.metricsPort" . | quote }} + TSS_METRICS_TYPE: {{ .Values.router.tss.metricsType | default "TSS_PROMETHEUS" | quote }} diff --git a/helmchart/sdp/templates/02-deployment-ap.yaml b/helmchart/sdp/templates/02-deployment-ap.yaml new file mode 100644 index 000000000..cd673bd70 --- /dev/null +++ b/helmchart/sdp/templates/02-deployment-ap.yaml @@ -0,0 +1,126 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "sdp.fullname" . }}-ap + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labelsWithSuffix" (list . "-ap") | nindent 4 }} + {{- if .Values.anchorPlatform.deployment.annotations }} + annotations: + {{- tpl (toYaml .Values.anchorPlatform.deployment.annotations) . | nindent 4 }} + {{- end }} +spec: + {{- if not .Values.autoscaling.enabled }} + replicas: {{ .Values.replicaCount }} + {{- end }} + selector: + matchLabels: + {{- include "sdp.selectorLabelsWithSuffix" (list . "-ap") | nindent 6 }} + + {{- if .Values.anchorPlatform.deployment.strategy }} + strategy: + {{- toYaml .Values.anchorPlatform.deployment.strategy | nindent 4 }} + {{- end }} + + template: + metadata: + {{- if .Values.anchorPlatform.deployment.podAnnotations }} + annotations: + {{- tpl (toYaml .Values.anchorPlatform.deployment.podAnnotations) . | nindent 8 }} + {{- end }} + labels: + {{- include "sdp.selectorLabelsWithSuffix" (list . "-ap") | nindent 8 }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + # {{- if .Values.serviceAccount.name }} + # serviceAccountName: {{ tpl .Values.serviceAccount.name $ }} + # {{- end }} + securityContext: + {{- tpl (toYaml .Values.anchorPlatform.deployment.podSecurityContext) . | nindent 8 }} + + containers: + # ============================= Anchor Platform: ============================= + - name: {{ .Chart.Name }}-ap + securityContext: + {{- tpl (toYaml .Values.anchorPlatform.deployment.securityContext) . | nindent 12 }} + image: "stellar/anchor-platform:2.1.3" + imagePullPolicy: "IfNotPresent" + {{- if .Values.ephemeralDatabase }} + env: + - name: DATA_TYPE + value: 'postgres' + - name: DATA_SERVER + value: '{{ include "sdp.fullname" . }}-psql.{{ .Release.Namespace }}.svc.cluster.local:5433' + - name: DATA_DATABASE + value: 'postgres-ap' + - name: SECRET_DATA_USERNAME + value: 'postgres' + - name: SECRET_DATA_PASSWORD + value: 'postgres' + - name: SDP_IMAGE_TAG # This env is used to force the AP to be redeployed every time the SDP is deployed. This is used to force the SDP to re-fetch the toml file and assets to ensure the latest ones are used. + value: {{ .Values.image.tag }} + {{- end }} + args: + - "--sep-server" + - "--platform-server" + ports: + - name: ap-sep + containerPort: {{ include "sdp.ap.sepPort" . }} + protocol: TCP + - name: ap-platform + containerPort: {{ include "sdp.ap.platformPort" . }} + protocol: TCP + - name: ap-metrics + containerPort: {{ include "sdp.ap.metricsPort" . }} + protocol: TCP + livenessProbe: + httpGet: + path: /health?checks=config + port: ap-sep + initialDelaySeconds: 60 + periodSeconds: 15 + failureThreshold: 10 + readinessProbe: + httpGet: + path: /health?checks=config + port: ap-sep + initialDelaySeconds: 60 + periodSeconds: 15 + failureThreshold: 10 + startupProbe: + httpGet: + path: /health?checks=config + port: ap-sep + initialDelaySeconds: 60 + periodSeconds: 15 + failureThreshold: 10 + + {{- if .Values.anchorPlatform.deployment.resources }} + resources: + {{- tpl (toYaml .Values.anchorPlatform.deployment.resources) . | nindent 12 }} + {{- end }} + + envFrom: + - configMapRef: + name: {{ include "sdp.fullname" . }}-ap + + {{- if .Values.anchorPlatform.secretName }} + - secretRef: + name: {{ .Values.anchorPlatform.secretName }} + {{ end }} + + {{- with .Values.deployment.nodeSelector }} + nodeSelector: + {{- tpl (toYaml .) . | nindent 8 }} + {{- end }} + {{- with .Values.deployment.affinity }} + affinity: + {{- tpl (toYaml .) . | nindent 8 }} + {{- end }} + {{- with .Values.deployment.tolerations }} + tolerations: + {{- tpl (toYaml .) . | nindent 8 }} + {{- end }} diff --git a/helmchart/sdp/templates/02-deployment-sdp.yaml b/helmchart/sdp/templates/02-deployment-sdp.yaml new file mode 100644 index 000000000..faeaff3cd --- /dev/null +++ b/helmchart/sdp/templates/02-deployment-sdp.yaml @@ -0,0 +1,141 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "sdp.fullname" . }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} + {{- if .Values.deployment.annotations }} + annotations: + {{- tpl (toYaml .Values.deployment.annotations) . | nindent 4 }} + {{- end }} +spec: + {{- if not .Values.autoscaling.enabled }} + replicas: {{ .Values.replicaCount }} + {{- end }} + selector: + matchLabels: + {{- include "sdp.selectorLabels" . | nindent 6 }} + + {{- if .Values.deployment.strategy }} + strategy: + {{- toYaml .Values.deployment.strategy | nindent 4 }} + {{- end }} + + template: + metadata: + {{- if .Values.deployment.podAnnotations }} + annotations: + {{- tpl (toYaml .Values.deployment.podAnnotations) . | nindent 8 }} + {{- end }} + labels: + {{- include "sdp.selectorLabels" . | nindent 8 }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- if .Values.serviceAccount.name }} + serviceAccountName: {{ tpl .Values.serviceAccount.name $ }} + {{- end }} + securityContext: + {{- toYaml .Values.deployment.podSecurityContext | nindent 8 }} + + initContainers: + # ============================= SDP Migrations: ============================= + - name: db-migrations + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + envFrom: + - configMapRef: + name: {{ include "sdp.fullname" . }} + + {{- if .Values.secretName }} + - secretRef: + name: {{ .Values.secretName }} + {{ end }} + {{- if .Values.ephemeralDatabase }} + env: + - name: DATABASE_URL + value: 'postgres://postgres:postgres@{{ include "sdp.fullname" . }}-psql.{{ .Release.Namespace }}.svc.cluster.local:5432/postgres-sdp?sslmode=disable' + {{- end }} + command: + - sh + - -c + - | + ./stellar-disbursement-platform db migrate up && + ./stellar-disbursement-platform db auth migrate up && + ./stellar-disbursement-platform db setup-for-network && + ./stellar-disbursement-platform channel-accounts verify --delete-invalid-accounts && + ./stellar-disbursement-platform channel-accounts ensure --num-channel-accounts-ensure {{ .Values.configMap.data.NUM_CHANNEL_ACCOUNTS | default 1 }} + + containers: + # ============================= Stellar Disbursement Platform: ============================= + - name: {{ .Chart.Name }} + securityContext: + {{- toYaml .Values.deployment.securityContext | nindent 12 }} + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + {{- if .Values.ephemeralDatabase }} + env: + - name: DATABASE_URL + value: 'postgres://postgres:postgres@{{ include "sdp.fullname" . }}-psql.{{ .Release.Namespace }}.svc.cluster.local:5432/postgres-sdp?sslmode=disable' + {{- end }} + args: + - "serve" + ports: + - name: http + containerPort: {{ include "sdp.port" . }} + protocol: TCP + - name: metrics + containerPort: {{ include "sdp.metricsPort" . }} + protocol: TCP + livenessProbe: + httpGet: + path: /health + port: http + initialDelaySeconds: 5 + periodSeconds: 15 + failureThreshold: 10 + readinessProbe: + httpGet: + path: /health + port: http + initialDelaySeconds: 5 + periodSeconds: 15 + failureThreshold: 10 + startupProbe: + httpGet: + path: /health + port: http + initialDelaySeconds: 5 + periodSeconds: 15 + failureThreshold: 10 + + {{- if .Values.resources }} + resources: + {{- toYaml .Values.resources | nindent 12 }} + {{- end }} + + envFrom: + - configMapRef: + name: {{ include "sdp.fullname" . }} + + {{- if .Values.secretName }} + - secretRef: + name: {{ .Values.secretName }} + {{ end }} + + + {{- with .Values.deployment.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.deployment.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.deployment.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/helmchart/sdp/templates/02.1-ephemeral-postgres.yaml b/helmchart/sdp/templates/02.1-ephemeral-postgres.yaml new file mode 100644 index 000000000..bd513641d --- /dev/null +++ b/helmchart/sdp/templates/02.1-ephemeral-postgres.yaml @@ -0,0 +1,85 @@ +{{- if .Values.ephemeralDatabase -}} +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "sdp.fullname" . }}-psql + namespace: {{ .Release.Namespace }} +spec: + replicas: 1 + selector: + matchLabels: + app: {{ include "sdp.fullname" . }}-psql + template: + metadata: + labels: + app: {{ include "sdp.fullname" . }}-psql + spec: + containers: + # =================== SDP Ephemeral Postgres DB =================== + - name: {{ include "sdp.fullname" . }}-psql-sdp + image: postgres:12-alpine + imagePullPolicy: "IfNotPresent" + ports: + - name: postgres-sdp + containerPort: 5432 # Exposes container port + protocol: TCP + env: + - name: POSTGRES_DB + value: "postgres-sdp" + - name: POSTGRES_USER + value: "postgres" + - name: POSTGRES_PASSWORD + value: "postgres" + - name: PGPORT + value: "5432" + volumeMounts: + - mountPath: /var/lib/postgresql/data-sdp + name: postgredb-sdp + + # =================== AP Ephemeral Postgres DB =================== + - name: {{ include "sdp.fullname" . }}-psql-ap + image: postgres:12-alpine + imagePullPolicy: "IfNotPresent" + ports: + - name: postgres-ap + containerPort: 5433 # Exposes container port + protocol: TCP + env: + - name: POSTGRES_DB + value: "postgres-ap" + - name: POSTGRES_USER + value: "postgres" + - name: POSTGRES_PASSWORD + value: "postgres" + - name: PGPORT + value: "5433" + volumeMounts: + - mountPath: /var/lib/postgresql/data-ap + name: postgredb-ap + + # =================== Volumes =================== + volumes: + - name: postgredb-sdp + - name: postgredb-ap +--- +apiVersion: v1 +kind: Service +metadata: + name: {{ include "sdp.fullname" . }}-psql + namespace: {{ .Release.Namespace }} +spec: + selector: + app: {{ include "sdp.fullname" . }}-psql + ports: + # =================== SDP =================== + - port: 5432 + targetPort: postgres-sdp + name: postgres-sdp + protocol: TCP + # =================== AP =================== + - port: 5433 + targetPort: postgres-ap + name: postgres-ap + protocol: TCP + type: ClusterIP +{{- end }} diff --git a/helmchart/sdp/templates/02.2-autoscaling.yaml b/helmchart/sdp/templates/02.2-autoscaling.yaml new file mode 100644 index 000000000..c6b1305c1 --- /dev/null +++ b/helmchart/sdp/templates/02.2-autoscaling.yaml @@ -0,0 +1,33 @@ +{{- if .Values.autoscaling.enabled }} +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: {{ include "sdp.fullname" . }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: {{ include "sdp.fullname" . }} + minReplicas: {{ .Values.autoscaling.minReplicas }} + maxReplicas: {{ .Values.autoscaling.maxReplicas }} + metrics: + {{- if .Values.autoscaling.targetCPUUtilizationPercentage }} + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }} + {{- end }} + {{- if .Values.autoscaling.targetMemoryUtilizationPercentage }} + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }} + {{- end }} +{{- end }} diff --git a/helmchart/sdp/templates/02.3-deployment-tss.yaml b/helmchart/sdp/templates/02.3-deployment-tss.yaml new file mode 100644 index 000000000..6ad8324c1 --- /dev/null +++ b/helmchart/sdp/templates/02.3-deployment-tss.yaml @@ -0,0 +1,90 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "sdp.fullname" . }}-tss + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} + {{- if .Values.deployment.annotations }} + annotations: + {{- tpl (toYaml .Values.deployment.annotations) . | nindent 4 }} + {{- end }} +spec: + {{- if not .Values.autoscaling.enabled }} + replicas: {{ .Values.replicaCount }} + {{- end }} + selector: + matchLabels: + {{- include "sdp.selectorLabels" . | nindent 6 }} + + {{- if .Values.deployment.strategy }} + strategy: + {{- toYaml .Values.deployment.strategy | nindent 4 }} + {{- end }} + + template: + metadata: + {{- if .Values.deployment.tssPodAnnotations }} + annotations: + {{- tpl (toYaml .Values.deployment.tssPodAnnotations) . | nindent 8 }} + {{- end }} + labels: + {{- include "sdp.selectorLabels" . | nindent 8 }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- if .Values.serviceAccount.name }} + serviceAccountName: {{ tpl .Values.serviceAccount.name $ }} + {{- end }} + securityContext: + {{- toYaml .Values.deployment.podSecurityContext | nindent 8 }} + + containers: + # ============================= Stellar Disbursement Platform: ============================= + - name: {{ .Chart.Name }}-tss + securityContext: + {{- toYaml .Values.deployment.securityContext | nindent 12 }} + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + {{- if .Values.ephemeralDatabase }} + env: + - name: DATABASE_URL + value: 'postgres://postgres:postgres@{{ include "sdp.fullname" . }}-psql.{{ .Release.Namespace }}.svc.cluster.local:5432/postgres-sdp?sslmode=disable' + {{- end }} + args: + - "tss" + ports: + - name: http + containerPort: {{ include "tss.port" . }} + protocol: TCP + - name: metrics + containerPort: {{ include "tss.metricsPort" . }} + protocol: TCP + + {{- if .Values.resources }} + resources: + {{- toYaml .Values.resources | nindent 12 }} + {{- end }} + + envFrom: + - configMapRef: + name: {{ include "sdp.fullname" . }}-tss + + {{- if .Values.secretName }} + - secretRef: + name: {{ .Values.secretName }} + {{ end }} + {{- with .Values.deployment.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.deployment.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.deployment.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/helmchart/sdp/templates/03-service-ap.yaml b/helmchart/sdp/templates/03-service-ap.yaml new file mode 100644 index 000000000..777dd78c7 --- /dev/null +++ b/helmchart/sdp/templates/03-service-ap.yaml @@ -0,0 +1,21 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "sdp.fullname" . }}-ap + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labelsWithSuffix" (list . "-ap") | nindent 4 }} +spec: + type: {{ .Values.service.type }} + ports: + - port: {{ include "sdp.ap.sepPort" . }} + targetPort: ap-sep + protocol: TCP + name: ap-sep + - port: {{ include "sdp.ap.platformPort" . }} + targetPort: ap-platform + protocol: TCP + name: ap-platform + + selector: + {{- include "sdp.selectorLabelsWithSuffix" (list . "-ap") | nindent 4 }} diff --git a/helmchart/sdp/templates/03-service-sdp.yaml b/helmchart/sdp/templates/03-service-sdp.yaml new file mode 100644 index 000000000..2d278b5d4 --- /dev/null +++ b/helmchart/sdp/templates/03-service-sdp.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "sdp.fullname" . }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} +spec: + type: {{ .Values.service.type }} + ports: + - port: {{ include "sdp.port" . }} + targetPort: http + protocol: TCP + name: http + + selector: + {{- include "sdp.selectorLabels" . | nindent 4 }} diff --git a/helmchart/sdp/templates/03.1-serviceaccount.yaml b/helmchart/sdp/templates/03.1-serviceaccount.yaml new file mode 100644 index 000000000..b111fc5e4 --- /dev/null +++ b/helmchart/sdp/templates/03.1-serviceaccount.yaml @@ -0,0 +1,13 @@ +{{- if .Values.serviceAccount.create -}} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ tpl .Values.serviceAccount.name $ }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} + {{- with .Values.serviceAccount.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +{{- end }} diff --git a/helmchart/sdp/templates/04-ingress.yaml b/helmchart/sdp/templates/04-ingress.yaml new file mode 100644 index 000000000..862737584 --- /dev/null +++ b/helmchart/sdp/templates/04-ingress.yaml @@ -0,0 +1,56 @@ +{{- if .Values.ingress.enabled -}} +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: {{ include "sdp.fullname" . }} + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} + annotations: + {{- toYaml .Values.ingress.annotations | nindent 4 }} + # This is a way to block the stellar.toml file from being served on the "sdp.domain": + nginx.ingress.kubernetes.io/server-snippet: | + location ~ /.well-known/stellar.toml { + if ($host = {{ include "sdp.domain" . | quote }}) { + return 404; + } + if ($host = {{ include "sdp.ap.domain" . | quote }}) { + proxy_pass {{ include "sdp.ap.sepServiceAddress" . }}; + } + } +spec: + {{- if .Values.ingress.className }} + ingressClassName: {{ .Values.ingress.className }} + {{- end }} + {{- if .Values.ingress.tls }} + tls: + {{- tpl (toYaml .Values.ingress.tls) . | nindent 4 }} + {{- end }} + rules: + - host: {{ include "sdp.domain" . | quote }} + http: + paths: + - path: / + pathType: Prefix + backend: + service: + name: {{ include "sdp.fullname" . }} + port: + number: {{ include "sdp.port" . }} + - host: {{ include "sdp.ap.domain" . | quote }} + http: + paths: + # Only enable the AP endpints that are needed for this application: + {{- $service_name := printf "%s-ap" (include "sdp.fullname" .) }} + {{- $service_sep_port := include "sdp.ap.sepPort" . }} + {{- $paths := list "/health" "/.well-known" "/auth" "/sep24" -}} + {{- range $path := $paths }} + - path: {{ $path }} + pathType: Prefix + backend: + service: + name: {{ $service_name }} + port: + number: {{ $service_sep_port }} + {{- end }} +{{- end }} diff --git a/helmchart/sdp/templates/NOTES.txt b/helmchart/sdp/templates/NOTES.txt new file mode 100644 index 000000000..45e180274 --- /dev/null +++ b/helmchart/sdp/templates/NOTES.txt @@ -0,0 +1,18 @@ +1. Get the application URL by running these commands: +{{- if .Values.ingress.enabled }} +{{- range $host := .Values.ingress.hosts }} + {{- range .paths }} + http{{ if $.Values.ingress.tls }}s{{ end }}://{{ $host.host }}{{ .path }} + {{- end }} +{{- end }} +{{- else if contains "NodePort" .Values.service.type }} + export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "sdp.fullname" . }}) && export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}") && echo http://$NODE_IP:$NODE_PORT +{{- else if contains "LoadBalancer" .Values.service.type }} + NOTE: It may take a few minutes for the LoadBalancer IP to be available. + You can watch the status of by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "sdp.fullname" . }}' + export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "sdp.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}") + echo http://$SERVICE_IP:{{ include "sdp.port" . }} +{{- else if contains "ClusterIP" .Values.service.type }} + export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "sdp.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}") && export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}") && echo "Visit http://127.0.0.1:8080 to use your application" + kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT +{{- end }} diff --git a/helmchart/sdp/templates/_helpers.tpl b/helmchart/sdp/templates/_helpers.tpl new file mode 100644 index 000000000..155402a9f --- /dev/null +++ b/helmchart/sdp/templates/_helpers.tpl @@ -0,0 +1,178 @@ +{{/* +Expand the name of the chart. +*/}} +{{- define "sdp.name" -}} +{{- default .Chart.Name .Release.Name | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +If release name contains chart name it will be used as a full name. +*/}} +{{- define "sdp.fullname" -}} +{{- default .Chart.Name .Release.Name | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "sdp.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels with suffix +*/}} +{{- define "sdp.labelsWithSuffix" -}} +{{- $ctx := index . 0 -}} +{{- $suffix := index . 1 | default "" -}} +helm.sh/chart: {{ include "sdp.chart" $ctx }} +{{ include "sdp.selectorLabelsWithSuffix" (list $ctx $suffix) }} +{{- if $ctx.Chart.AppVersion }} +app.kubernetes.io/version: {{ $ctx.Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ $ctx.Release.Service }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "sdp.labels" -}} +helm.sh/chart: {{ include "sdp.chart" . }} +{{ include "sdp.selectorLabels" . }} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels with suffix +*/}} +{{- define "sdp.selectorLabelsWithSuffix" -}} +{{- $ctx := index . 0 -}} +{{- $suffix := index . 1 | default "" -}} +app.kubernetes.io/name: {{ include "sdp.name" $ctx }}{{ $suffix }} +app.kubernetes.io/instance: {{ $ctx.Release.Name }}{{ $suffix }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "sdp.selectorLabels" -}} +{{ include "sdp.selectorLabelsWithSuffix" (list . "") }} +{{- end }} + +{{/* +SDP domain +*/}} +{{- define "sdp.domain" -}} +{{- .Values.router.sdp.domain | default "localhost" }} +{{- end }} + +{{/* +SDP domain schema +*/}} +{{- define "sdp.schema" -}} +{{- .Values.router.sdp.schema | default "https" }} +{{- end }} + +{{/* +SDP port +*/}} +{{- define "sdp.port" -}} +{{- .Values.router.sdp.port | default "8000" }} +{{- end }} + +{{/* +SDP Metrics port +*/}} +{{- define "sdp.metricsPort" -}} +{{- .Values.router.sdp.metricsPort | default "8002" }} +{{- end }} + +{{/* +Define the full address to the SDP service. +*/}} +{{- define "sdp.serviceAddress" -}} +http://{{ include "sdp.fullname" . }}.{{ .Release.Namespace }}.svc.cluster.local:{{ include "sdp.port" . }} +{{- end -}} + +{{/* +TSS port +*/}} +{{- define "tss.port" -}} +{{- .Values.router.tss.port | default "9000" }} +{{- end }} + +{{/* +TSS Metrics port +*/}} +{{- define "tss.metricsPort" -}} +{{- .Values.router.tss.metricsPort | default "9002" }} +{{- end }} + +{{/* +Anchor Platform domain +*/}} +{{- define "sdp.ap.domain" -}} +{{- .Values.router.ap.domain | default "localhost" }} +{{- end }} + +{{/* +SDP domain schema +*/}} +{{- define "sdp.ap.schema" -}} +{{- .Values.router.ap.schema | default "https" }} +{{- end }} + +{{/* +Anchor SEP port +*/}} +{{- define "sdp.ap.sepPort" -}} +{{- .Values.router.ap.sepPort | default "8080" }} +{{- end }} + +{{/* +Anchor Platform port +*/}} +{{- define "sdp.ap.platformPort" -}} +{{- .Values.router.ap.platformPort | default "8085" }} +{{- end }} + + +{{/* +Anchor Platform metrics port +*/}} +{{- define "sdp.ap.metricsPort" -}} +{{- 8082 }} +{{- end }} + +{{/* +Define the full address to the AP SEP service. +*/}} +{{- define "sdp.ap.sepServiceAddress" -}} +http://{{ include "sdp.fullname" . }}-ap.{{ .Release.Namespace }}.svc.cluster.local:{{ include "sdp.ap.sepPort" . }} +{{- end -}} + +{{/* +Define the full address to the AP Platform service. +*/}} +{{- define "sdp.ap.platformServiceAddress" -}} +http://{{ include "sdp.fullname" . }}-ap.{{ .Release.Namespace }}.svc.cluster.local:{{ include "sdp.ap.platformPort" . }} +{{- end -}} + +{{/* +Is Pubnet? +*/}} +{{- define "isPubnet" -}} +{{- eq .Values.isPubnet true | default false }} +{{- end }} + +{{/* +Image Tag +*/}} +{{- define "imageTag" -}} +{{- .Values.image.tag | default .Chart.AppVersion }} +{{- end }} diff --git a/helmchart/sdp/templates/tests/test-connection.yaml b/helmchart/sdp/templates/tests/test-connection.yaml new file mode 100644 index 000000000..928fdb0e6 --- /dev/null +++ b/helmchart/sdp/templates/tests/test-connection.yaml @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: Pod +metadata: + name: "{{ include "sdp.fullname" . }}-test-connection" + namespace: {{ .Release.Namespace }} + labels: + {{- include "sdp.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": test +spec: + containers: + - name: wget + image: busybox + command: ['wget'] + args: ['{{ include "sdp.fullname" . }}:{{ include "sdp.port" . }}'] + restartPolicy: Never diff --git a/helmchart/sdp/values.yaml b/helmchart/sdp/values.yaml new file mode 100644 index 000000000..dedc800bf --- /dev/null +++ b/helmchart/sdp/values.yaml @@ -0,0 +1,269 @@ +# EXAMPLE values for the Stellar Disbursement Platform (SDP) Backend. +# Repo: https://github.com/stellar/stellar-disbursement-platform-backend +# +# This is a YAML-formatted file where you declare the variables to be passed into your templates. + + +# =========================== START Router =========================== +--- +router: + sdp: + schema: "https" + # The domain is the public address of that service. For localhost you will want to include the port as part of the domain. + # domain: "localhost:8000" + domain: sdp.localhost.com + port: "8000" + metricsPort: "8002" + tss: + schema: "https" + port: "9000" + metricsPort: "9002" + metricsType: "TSS_PROMETHEUS" + maxBaseFee: "100" + mock: "false" + numChannelAccounts: 1 + ap: + schema: "https" + # The domain is the public address of that service. For localhost you will want to include the port as part of the domain. + # domain: "localhost:8080" + domain: "ap.localhost.com" + sepPort: "8080" + platformPort: "8085" + +isPubnet: false + +# =========================== START Image =========================== +# replicaCount is used only if autoscaling.enabled is set to false: +replicaCount: 1 + +image: + # replace the `repository` with the actual image name. + repository: stellar/sdp-v2 + # for locally built images, use `pullPolicy: Never or IfNotPresent`. + pullPolicy: IfNotPresent + # If `tag` is set, it'll override the default value set in `.Chart.AppVersion`. + tag: "latest" + +imagePullSecrets: [] + +# =========================== START ConfigMap =========================== +# ConfigMap is used to configure the SDP: +configMap: + # Annotations to add to the ConfigMap + annotations: + fluxcd.io/ignore: "true" + # The data to be stored in the ConfigMap + data: + # auth + EC256_PUBLIC_KEY: "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEJ3HNphPAEKHvtRjsl5Kjwc9tTMqS\n2pmYNybrLsxZ6cuQvg2yiEoXZixP2cJ77csHClXC6cb1wQp/BNGDvGKoPg==\n-----END PUBLIC KEY-----" + # general + ENVIRONMENT: "localhost" + LOG_LEVEL: "TRACE" + # serve + SEP10_SIGNING_PUBLIC_KEY: GDA34JZ26FZY64XCSY46CUNSHLX762LHJXQHWWHGL5HSFRWSGBVHUFNI + DISTRIBUTION_PUBLIC_KEY: GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA + # serve metrics + METRICS_TYPE: "PROMETHEUS" + # message senders + EMAIL_SENDER_TYPE: DRY_RUN + SMS_SENDER_TYPE: DRY_RUN + # reCaptcha + RECAPTCHA_SITE_KEY: reCaptchaSiteKey + # CORS Allowed Origins - "*" value for PR preview purpose + CORS_ALLOWED_ORIGINS: "*" + # enable recaptcha for login and forget password + ENABLE_RECAPTCHA: "false" + # enable email-based MFA during login + ENABLE_MFA: "false" + + +# =========================== START Secret =========================== +secretName: "stellar-disbursement-platform-backend" # You need to create this secret manually +# Secrets used to configure the SDP: +# - ANCHOR_PLATFORM_OUTGOING_JWT_SECRET +# - AWS_ACCESS_KEY_ID +# - AWS_REGION +# - AWS_SECRET_ACCESS_KEY +# - AWS_SES_SENDER_ID +# - DATABASE_URL +# - DISTRIBUTION_SEED # TSS distribution account seed +# - EC256_PRIVATE_KEY +# - RECAPTCHA_SITE_SECRET_KEY +# - SEP10_SIGNING_PRIVATE_KEY +# - SEP24_JWT_SECRET +# - TWILIO_ACCOUNT_SID +# - TWILIO_AUTH_TOKEN +# - TWILIO_SERVICE_SID + +# =========================== START Service =========================== + +service: + type: ClusterIP + + +# =========================== START serviceAccount =========================== + +# Not used in SDF's deployment: +serviceAccount: + # Specifies whether a service account should be created + create: false + # Annotations to add to the service account + annotations: + fluxcd.io/ignore: "true" + # The name of the service account to use. + # If not set and create is true, a name is generated using the fullname template + name: "" + + +# =========================== START Deployment =========================== +deployment: + annotations: + fluxcd.io/ignore: "true" + podAnnotations: + prometheus.io/path: /metrics + prometheus.io/port: '{{ include "sdp.metricsPort" . }}' + prometheus.io/scrape: "true" + tssPodAnnotations: + prometheus.io/path: /metrics + prometheus.io/port: '{{ include "tss.metricsPort" . }}' + prometheus.io/scrape: "true" + podSecurityContext: {} + # fsGroup: 2000 + securityContext: {} + # capabilities: + # drop: + # - ALL + # readOnlyRootFilesystem: true + # runAsNonRoot: true + # runAsUser: 1000 + strategy: + # Ensure we upgrade 1 pod at a time to avoid migration races + type: "RollingUpdate" + rollingUpdate: + maxUnavailable: 0 + maxSurge: 1 + + nodeSelector: {} + + tolerations: [] + + affinity: {} + +resources: {} + # We usually recommend not to specify default resources and to leave this as a conscious + # choice for the user. This also increases chances charts run on environments with little + # resources, such as Minikube. If you do want to specify resources, uncomment the following + # lines, adjust them as necessary, and remove the curly braces after 'resources:'. + # limits: + # cpu: 250m + # memory: 512Mi + # requests: + # cpu: 50m + # memory: 256Mi + + +# =========================== START Autoscaling =========================== + +autoscaling: + enabled: false + minReplicas: 1 + maxReplicas: 4 + targetCPUUtilizationPercentage: 80 + targetMemoryUtilizationPercentage: 80 + + +# =========================== START Postgres (TESTING database) =========================== + +# `ephemeralDatabase` will create an ephemeral database for testing purposes. +# If this option is enabled, make sure to set the `DATABASE_URL` environment variable to: +# postgres://postgres:postgres@{{ include "sdp.fullname" . }}-psql:5432/postgres-sdp?sslmode=disable +# The AP database URL will be set automatically to: +# postgres://postgres:postgres@{{ include "sdp.fullname" . }}-psql:5433/postgres-ap?sslmode=disable +ephemeralDatabase: true + + +# =========================== START Ingress =========================== +ingress: + enabled: true + className: "nginx" + annotations: + fluxcd.io/ignore: "true" # Doesn't rely on flux to update the images from stellar/kube + # kubernetes.io/ingress.class: "public" + nginx.ingress.kubernetes.io/custom-response-headers: "X-XSS-Protection: 1; mode=block || X-Frame-Options: DENY || X-Content-Type-Options: nosniff || Strict-Transport-Security: max-age=31536000; includeSubDomains" + tls: + - hosts: + - '{{ include "sdp.domain" . }}' + - '{{ include "sdp.ap.domain" . }}' + secretName: stellar-disbursement-platform-backend-tls-cert # You need to create this secret manually. For more instructions, plz refer to helmchart/docs/README.md + # NOTE: the hosts to be used here will be the same ones as in the router section, at the top of this file. + + +# =========================== START AnchorPlatform =========================== +anchorPlatform: + deployment: + annotations: + fluxcd.io/ignore: "true" + podAnnotations: + prometheus.io/path: /metrics + prometheus.io/port: '{{ include "sdp.ap.metricsPort" . }}' + prometheus.io/scrape: "true" + strategy: + # Ensure we upgrade 1 pod at a time to avoid migration races + type: "RollingUpdate" + rollingUpdate: + maxUnavailable: 0 + maxSurge: 1 + podSecurityContext: + {} + securityContext: + {} + resources: + {} + + + # =========================== START anchorPlatform.configMap =========================== + configMap: + # Annotations to add to the ConfigMap + annotations: + fluxcd.io/ignore: "true" + # The data to be stored in the ConfigMap + data: + APP_LOGGING_LEVEL: INFO + # DATA_DATABASE: # will be automatically populated in the development helm chart when we have `ephemeralDatabase` enabled. + # DATA_SERVER: # will be automatically populated in the development helm chart when we have `ephemeralDatabase` enabled. + # DATA_FLYWAY_ENABLED: true # TODO: test this later + DATA_DDL_AUTO: update + METRICS_ENABLED: "false" # Metrics would be available at port 8082 + METRICS_EXTRAS_ENABLED: "false" + ASSETS_VALUE: | # TODO: update this later + { + "assets": [ + { + "sep24_enabled": true, + "schema": "stellar", + "code": "USDC", + "issuer": "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "distribution_account": "GDDSLDRLMIYZJOXPBWVTRU267TPXIJEYW6PSV7FMDBLFVZZI5AI4QV4F", + "significant_decimals": 7, + "deposit": { + "enabled": true, + "fee_minimum": 0, + "fee_percent": 0, + "min_amount": 1, + "max_amount": 10000 + }, + "withdraw": {"enabled": false} + } + ] + } + + # =========================== START anchorPlatform.secretName =========================== + secretName: "stellar-disbursement-platform-backend-ap" # you need to create this secret manually + # Secrets used to configure the AP: + # - SECRET_DATA_USERNAME + # - SECRET_DATA_PASSWORD + # - SECRET_SEP10_JWT_SECRET + # - SECRET_SEP10_SIGNING_SEED + # - SECRET_SEP24_INTERACTIVE_URL_JWT_SECRET + # - SECRET_SEP24_MORE_INFO_URL_JWT_SECRET + # - SECRET_PLATFORM_API_AUTH_SECRET diff --git a/internal/anchorplatform/jwt_manager.go b/internal/anchorplatform/jwt_manager.go new file mode 100644 index 000000000..1155448eb --- /dev/null +++ b/internal/anchorplatform/jwt_manager.go @@ -0,0 +1,116 @@ +package anchorplatform + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +var ErrInvalidToken = fmt.Errorf("invalid token") + +type JWTManager struct { + secret []byte + expirationMiliseconds int64 +} + +// NewJWTManager creates a new JWTManager instance based on the provided secret and expirationMiliseconds. +func NewJWTManager(secret string, expirationMiliseconds int64) (*JWTManager, error) { + const minSecretSize = 12 + if len(secret) < minSecretSize { + return nil, fmt.Errorf("secret is required to have at least %d characteres", minSecretSize) + } + + const minExpirationMiliseconds = 5000 + if expirationMiliseconds < minExpirationMiliseconds { + return nil, fmt.Errorf("expiration miliseconds is required to be at least %d", minExpirationMiliseconds) + } + + return &JWTManager{secret: []byte(secret), expirationMiliseconds: expirationMiliseconds}, nil +} + +// GenerateSEP24Token will generate a JWT token string using the token manager and the provided parameters. +// The parameters are validated before generating the token. +func (manager *JWTManager) GenerateSEP24Token(stellarAccount, stellarMemo, clientDomain, transactionID string) (string, error) { + subject := stellarAccount + if stellarMemo != "" { + subject = fmt.Sprintf("%s:%s", stellarAccount, stellarMemo) + } + + claims := SEP24JWTClaims{ + ClientDomainClaim: clientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: transactionID, + Subject: subject, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Millisecond * time.Duration(manager.expirationMiliseconds))), + }, + } + err := claims.Valid() + if err != nil { + return "", fmt.Errorf("validating SEP24 token claims: %w", err) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, err := token.SignedString(manager.secret) + if err != nil { + return "", fmt.Errorf("signing SEP24 token: %w", err) + } + + return signedToken, nil +} + +// ParseSEP24TokenClaims will parse the provided token string and return the SEP24JWTClaims, if possible. +// If the token is not a valid SEP-24 token, an error is returned instead. +func (manager *JWTManager) ParseSEP24TokenClaims(tokenString string) (*SEP24JWTClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &SEP24JWTClaims{}, func(token *jwt.Token) (interface{}, error) { + return manager.secret, nil + }) + if err != nil { + return nil, fmt.Errorf("parsing SEP24 token: %w", err) + } + + claims, ok := token.Claims.(*SEP24JWTClaims) + if !ok || !token.Valid { + return nil, ErrInvalidToken + } + + return claims, nil +} + +// GenerateDefaultToken will generate a JWT token string using the token manager and only the default claims. +func (manager *JWTManager) GenerateDefaultToken(id string) (string, error) { + claims := jwt.RegisteredClaims{ + ID: id, + Subject: "stellar-disbursement-platform-backend", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Millisecond * time.Duration(manager.expirationMiliseconds))), + } + err := claims.Valid() + if err != nil { + return "", fmt.Errorf("validating token claims: %w", err) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, err := token.SignedString(manager.secret) + if err != nil { + return "", fmt.Errorf("signing default token: %w", err) + } + + return signedToken, nil +} + +// ParseDefaultTokenClaims will parse the default claims from a JWT token string. +func (manager *JWTManager) ParseDefaultTokenClaims(tokenString string) (*jwt.RegisteredClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + return manager.secret, nil + }) + if err != nil { + return nil, fmt.Errorf("parsing default token: %w", err) + } + + claims, ok := token.Claims.(*jwt.RegisteredClaims) + if !ok || !token.Valid { + return nil, ErrInvalidToken + } + + return claims, nil +} diff --git a/internal/anchorplatform/jwt_manager_test.go b/internal/anchorplatform/jwt_manager_test.go new file mode 100644 index 000000000..61fad054d --- /dev/null +++ b/internal/anchorplatform/jwt_manager_test.go @@ -0,0 +1,75 @@ +package anchorplatform + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewJWTManager(t *testing.T) { + jwtManager, err := NewJWTManager("", 0) + require.Nil(t, jwtManager) + require.EqualError(t, err, "secret is required to have at least 12 characteres") + + jwtManager, err = NewJWTManager("1234567890ab", 0) + require.Nil(t, jwtManager) + require.EqualError(t, err, "expiration miliseconds is required to be at least 5000") + + jwtManager, err = NewJWTManager("1234567890ab", 5000) + require.NotNil(t, jwtManager) + require.NoError(t, err) + wantManager := &JWTManager{ + secret: []byte("1234567890ab"), + expirationMiliseconds: 5000, + } + require.Equal(t, wantManager, jwtManager) +} + +func Test_JWTManager_GenerateAndParseSEP24Token(t *testing.T) { + jwtManager, err := NewJWTManager("1234567890ab", 5000) + require.NoError(t, err) + + // invalid claims + tokenStr, err := jwtManager.GenerateSEP24Token("", "", "test.com", "test-transaction-id") + require.EqualError(t, err, "validating SEP24 token claims: stellar account is invalid: strkey is 0 bytes long; minimum valid length is 5") + require.Empty(t, tokenStr) + + // valid claims πŸŽ‰ + tokenStr, err = jwtManager.GenerateSEP24Token("GB54GWWWOSHATX5ALKHBBL2IQBZ2E7TBFO7F7VXKPIW6XANYDK4Y3RRC", "123456", "test.com", "test-transaction-id") + require.NoError(t, err) + require.NotEmpty(t, tokenStr) + now := time.Now() + + // parse claims + claims, err := jwtManager.ParseSEP24TokenClaims(tokenStr) + require.NoError(t, err) + assert.Nil(t, claims.Valid()) + assert.Equal(t, "test-transaction-id", claims.TransactionID()) + assert.Equal(t, "GB54GWWWOSHATX5ALKHBBL2IQBZ2E7TBFO7F7VXKPIW6XANYDK4Y3RRC", claims.SEP10StellarAccount()) + assert.Equal(t, "123456", claims.SEP10StellarMemo()) + assert.Equal(t, "test.com", claims.ClientDomain()) + assert.True(t, claims.ExpiresAt().After(now.Add(time.Duration(4000*time.Millisecond)))) + assert.True(t, claims.ExpiresAt().Before(now.Add(time.Duration(5000*time.Millisecond)))) +} + +func Test_JWTManager_GenerateAndParseDefaultToken(t *testing.T) { + jwtManager, err := NewJWTManager("1234567890ab", 5000) + require.NoError(t, err) + + // valid claims πŸŽ‰ + tokenStr, err := jwtManager.GenerateDefaultToken("test-transaction-id") + require.NoError(t, err) + require.NotEmpty(t, tokenStr) + now := time.Now() + + // parse claims + claims, err := jwtManager.ParseDefaultTokenClaims(tokenStr) + require.NoError(t, err) + assert.Nil(t, claims.Valid()) + assert.Equal(t, "test-transaction-id", claims.ID) + assert.Equal(t, "stellar-disbursement-platform-backend", claims.Subject) + assert.True(t, claims.ExpiresAt.After(now.Add(time.Duration(4000*time.Millisecond)))) + assert.True(t, claims.ExpiresAt.Before(now.Add(time.Duration(5000*time.Millisecond)))) +} diff --git a/internal/anchorplatform/platform_api_service.go b/internal/anchorplatform/platform_api_service.go new file mode 100644 index 000000000..0e91ec810 --- /dev/null +++ b/internal/anchorplatform/platform_api_service.go @@ -0,0 +1,126 @@ +package anchorplatform + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" +) + +var ErrJWTManagerNotSet = fmt.Errorf("jwt manager not set") + +// TODO update with the PlatformAPI endpoints +type AnchorPlatformAPIServiceInterface interface { + UpdateAnchorTransactions(ctx context.Context, transactions []Transaction) error +} + +type AnchorPlatformAPIService struct { + HttpClient httpclient.HttpClientInterface + AnchorPlatformBasePlatformURL string + jwtManager *JWTManager +} + +type TransactionValues struct { + ID string `json:"id"` + Status string `json:"status,omitempty"` + Sep string `json:"sep,omitempty"` + Kind string `json:"kind,omitempty"` + DestinationAccount string `json:"destination_account,omitempty"` + Memo string `json:"memo,omitempty"` + KYCVerified bool `json:"kyc_verified,omitempty"` +} + +type Transaction struct { + TransactionValues TransactionValues `json:"transaction"` +} + +type TransactionRecords struct { + Transactions []Transaction `json:"records"` +} + +func NewAnchorPlatformAPIService(httpClient httpclient.HttpClientInterface, anchorPlatformBasePlatformURL, anchorPlatformOutgoingJWTSecret string) (*AnchorPlatformAPIService, error) { + apService := AnchorPlatformAPIService{ + HttpClient: httpClient, + AnchorPlatformBasePlatformURL: anchorPlatformBasePlatformURL, + } + + if anchorPlatformOutgoingJWTSecret != "" { + const expirationMiliseconds = 5000 + jwtManager, err := NewJWTManager(anchorPlatformOutgoingJWTSecret, expirationMiliseconds) + if err != nil { + return nil, fmt.Errorf("error creating jwt manager: %w", err) + } + apService.jwtManager = jwtManager + } + + return &apService, nil +} + +func (a *AnchorPlatformAPIService) UpdateAnchorTransactions(ctx context.Context, transactions []Transaction) error { + records := TransactionRecords{transactions} + + recordsJSON, err := json.Marshal(records) + if err != nil { + return fmt.Errorf("error marshaling records: %w", err) + } + + u, err := url.JoinPath(a.AnchorPlatformBasePlatformURL, "transactions") + if err != nil { + return fmt.Errorf("error creating url: %w", err) + } + request, err := http.NewRequestWithContext(ctx, http.MethodPatch, u, strings.NewReader(string(recordsJSON))) + if err != nil { + return fmt.Errorf("error creating new request: %w", err) + } + request.Header.Set("Content-Type", "application/json") + + // If the service is configured with an outgoing JWT secret, we'll generate a JWT token and add it to the request. + token, err := a.GetJWTToken(transactions) + if err != nil { + if !errors.Is(err, ErrJWTManagerNotSet) { + return fmt.Errorf("error getting jwt token in UpdateAnchorTransactions: %w", err) + } + log.Ctx(ctx).Warn("JWT secret not set, skipping JWT token generation") + } else { + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + } + + response, err := a.HttpClient.Do(request) + if err != nil { + return fmt.Errorf("error making request to anchor platform: %w", err) + } + + if response.StatusCode/100 != 2 { + return fmt.Errorf("error updating transaction on anchor platform, response.StatusCode: %d", response.StatusCode) + } + + return nil +} + +// GetJWTToken will generate a JWT token if the service is configured with an outgoing JWT secret. +func (a *AnchorPlatformAPIService) GetJWTToken(transactions []Transaction) (string, error) { + if a.jwtManager == nil { + return "", ErrJWTManagerNotSet + } + + var txIDs []string + for _, tx := range transactions { + txIDs = append(txIDs, tx.TransactionValues.ID) + } + + token, err := a.jwtManager.GenerateDefaultToken(strings.Join(txIDs, ",")) + if err != nil { + return "", fmt.Errorf("error generating jwt token: %w", err) + } + + return token, nil +} + +// Ensuring that AnchorPlatformAPIService is implementing AnchorPlatformAPIServiceInterface. +var _ AnchorPlatformAPIServiceInterface = (*AnchorPlatformAPIService)(nil) diff --git a/internal/anchorplatform/platform_api_service_mock.go b/internal/anchorplatform/platform_api_service_mock.go new file mode 100644 index 000000000..06bc0ecf3 --- /dev/null +++ b/internal/anchorplatform/platform_api_service_mock.go @@ -0,0 +1,18 @@ +package anchorplatform + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type AnchorPlatformAPIServiceMock struct { + mock.Mock +} + +func (a *AnchorPlatformAPIServiceMock) UpdateAnchorTransactions(ctx context.Context, transactions []Transaction) error { + args := a.Called(ctx, transactions) + return args.Error(0) +} + +var _ AnchorPlatformAPIServiceInterface = (*AnchorPlatformAPIServiceMock)(nil) diff --git a/internal/anchorplatform/platform_api_service_test.go b/internal/anchorplatform/platform_api_service_test.go new file mode 100644 index 000000000..5eabf1743 --- /dev/null +++ b/internal/anchorplatform/platform_api_service_test.go @@ -0,0 +1,135 @@ +package anchorplatform + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_UpdateAnchorTransactions(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + anchorPlatformAPIService, err := NewAnchorPlatformAPIService(&httpClientMock, "http://mock_anchor.com/", "") + require.NoError(t, err) + ctx := context.Background() + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + transaction := &Transaction{ + TransactionValues: TransactionValues{ + ID: "test-transaction-id", + Status: "pending_anchor", + Sep: "24", + Kind: "deposit", + DestinationAccount: "stellar_address", + Memo: "stellar_memo", + KYCVerified: true, + }, + } + err := anchorPlatformAPIService.UpdateAnchorTransactions(ctx, []Transaction{*transaction}) + require.EqualError(t, err, "error making request to anchor platform: error calling the request") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to update transactions on anchor platform", func(t *testing.T) { + transactionResponse := `{The 'id' of the transaction first determined to be invalid.}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.Anything).Return(response, nil).Once() + + transaction := &Transaction{ + TransactionValues: TransactionValues{ + ID: "test-transaction-id", + Status: "pending_anchor", + Sep: "24", + Kind: "deposit", + DestinationAccount: "stellar_address", + Memo: "stellar_memo", + KYCVerified: true, + }, + } + err := anchorPlatformAPIService.UpdateAnchorTransactions(ctx, []Transaction{*transaction}) + require.EqualError(t, err, "error updating transaction on anchor platform, response.StatusCode: 400") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully update transaction on anchor platform", func(t *testing.T) { + transactionResponse := `{ + "transaction":{ + "id": "test-transaction-id", + "status": "pending_anchor", + "sep": "24", + "kind": "deposit", + "destination_account": "stellar_address", + "memo": "stellar_memo" + "kyc_verified": true, + } + }` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.Anything).Return(response, nil).Once() + + transaction := &Transaction{ + TransactionValues: TransactionValues{ + ID: "test-transaction-id", + Status: "pending_anchor", + Sep: "24", + Kind: "deposit", + DestinationAccount: "stellar_address", + Memo: "stellar_memo", + KYCVerified: true, + }, + } + err := anchorPlatformAPIService.UpdateAnchorTransactions(ctx, []Transaction{*transaction}) + require.NoError(t, err) + + httpClientMock.AssertExpectations(t) + }) +} + +func Test_GetJWTToken(t *testing.T) { + t.Run("returns ErrJWTSecretNotSet when a JWT secret is not set", func(t *testing.T) { + apService := AnchorPlatformAPIService{} + transactions := []Transaction{ + {TransactionValues{ID: "1"}}, + {TransactionValues{ID: "2"}}, + } + token, err := apService.GetJWTToken(transactions) + require.ErrorIs(t, err, ErrJWTManagerNotSet) + require.Empty(t, token) + }) + + t.Run("returns token successfully πŸŽ‰", func(t *testing.T) { + jwtManager, err := NewJWTManager("1234567890ab", 5000) + require.NoError(t, err) + + apService := AnchorPlatformAPIService{jwtManager: jwtManager} + transactions := []Transaction{ + {TransactionValues{ID: "1"}}, + {TransactionValues{ID: "2"}}, + } + token, err := apService.GetJWTToken(transactions) + require.NoError(t, err) + require.NotEmpty(t, token) + + // verify the token + claims, err := jwtManager.ParseDefaultTokenClaims(token) + require.NoError(t, err) + assert.Nil(t, claims.Valid()) + assert.Equal(t, "1,2", claims.ID) + assert.Equal(t, "stellar-disbursement-platform-backend", claims.Subject) + }) +} diff --git a/internal/anchorplatform/sep24_auth_middleware.go b/internal/anchorplatform/sep24_auth_middleware.go new file mode 100644 index 000000000..66b1615b0 --- /dev/null +++ b/internal/anchorplatform/sep24_auth_middleware.go @@ -0,0 +1,147 @@ +package anchorplatform + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/stellar/go/network" + "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" +) + +type ContextType string + +const SEP24ClaimsContextKey ContextType = "sep24_claims" + +func GetSEP24Claims(ctx context.Context) *SEP24JWTClaims { + claims := ctx.Value(SEP24ClaimsContextKey) + if claims == nil { + return nil + } + return claims.(*SEP24JWTClaims) +} + +type SEP24RequestQuery struct { + Token string `query:"token"` + TransactionID string `query:"transaction_id"` +} + +// checkSEP24ClientDomain check if the sep24 token has a client domain and if not check in which network the API is running on, +// only testnet can have an empty client domain. +func checkSEP24ClientDomain(ctx context.Context, sep24Claims *SEP24JWTClaims, networkPassphrase string) error { + if sep24Claims.ClientDomain() == "" { + missingDomain := "missing client domain in the token claims" + if networkPassphrase == network.PublicNetworkPassphrase { + log.Ctx(ctx).Error(missingDomain) + return fmt.Errorf(missingDomain) + } + log.Ctx(ctx).Warn(missingDomain) + } + return nil +} + +// SEP24QueryTokenAuthenticateMiddleware is a middleware that validates if the token passed in as a query +// parameter with ?token={token} is valid for the authenticated endpoints. +func SEP24QueryTokenAuthenticateMiddleware(jwtManager *JWTManager, networkPassphrase string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + // get the token from the request query parameters + var reqParams SEP24RequestQuery + if err := httpdecode.DecodeQuery(req, &reqParams); err != nil { + err = fmt.Errorf("decoding the request query parameters: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + // check if the token is present + if reqParams.Token == "" { + log.Ctx(ctx).Error("no token was provided in the request") + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + // parse the token claims + sep24Claims, err := jwtManager.ParseSEP24TokenClaims(reqParams.Token) + if err != nil { + err = fmt.Errorf("parsing the token claims: %w", err) + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(rw) + return + } + + // check if the transaction_id in the token claims matches the transaction_id in the request query parameters + if sep24Claims.TransactionID() != reqParams.TransactionID { + log.Ctx(ctx).Error("the transaction_id in the token claims does not match the transaction_id in the request query parameters") + httperror.BadRequest("", nil, nil).Render(rw) + return + } + + err = checkSEP24ClientDomain(ctx, sep24Claims, networkPassphrase) + if err != nil { + httperror.BadRequest("", err, nil).Render(rw) + return + } + + // Add the token to the request context + ctx = context.WithValue(ctx, SEP24ClaimsContextKey, sep24Claims) + req = req.WithContext(ctx) + + next.ServeHTTP(rw, req) + }) + } +} + +// SEP24HeaderTokenAuthenticateMiddleware is a middleware that validates if the token passed in +// the 'Authorization' header is valid for the authenticated endpoints. +func SEP24HeaderTokenAuthenticateMiddleware(jwtManager *JWTManager, networkPassphrase string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + // get the token from the Authorization header + authHeader := req.Header.Get("Authorization") + // check if the Authorization header is present + if authHeader == "" { + log.Ctx(ctx).Error("no token was provided in the Authorization header") + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + // check if the Authorization header has two parts ['Bearer', token] + if !strings.HasPrefix(authHeader, "Bearer ") { + log.Ctx(ctx).Error("invalid Authorization header provided") + httperror.Unauthorized("Invalid Authorization header provided.", nil, nil).Render(rw) + return + } + + // parse the token claims + token := strings.Replace(authHeader, "Bearer ", "", 1) + sep24Claims, err := jwtManager.ParseSEP24TokenClaims(token) + if err != nil { + err = fmt.Errorf("parsing the token claims: %w", err) + log.Ctx(ctx).Error(err) + + httperror.Unauthorized("", err, nil).Render(rw) + return + } + + err = checkSEP24ClientDomain(ctx, sep24Claims, networkPassphrase) + if err != nil { + httperror.BadRequest("", err, nil).Render(rw) + return + } + + // Add the token to the request context + ctx = context.WithValue(ctx, SEP24ClaimsContextKey, sep24Claims) + req = req.WithContext(ctx) + + next.ServeHTTP(rw, req) + }) + } +} diff --git a/internal/anchorplatform/sep24_auth_middleware_test.go b/internal/anchorplatform/sep24_auth_middleware_test.go new file mode 100644 index 000000000..7e49eb233 --- /dev/null +++ b/internal/anchorplatform/sep24_auth_middleware_test.go @@ -0,0 +1,607 @@ +package anchorplatform + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/golang-jwt/jwt/v4" + "github.com/stellar/go/network" + "github.com/stellar/go/support/log" + "github.com/stretchr/testify/require" +) + +func Test_GetSEP24Claims(t *testing.T) { + ctx := context.Background() + gotClaims := GetSEP24Claims(ctx) + require.Nil(t, gotClaims) + + wantClaims := &SEP24JWTClaims{ + ClientDomainClaim: "test.com", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444:123456", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Second)), + }, + } + ctx = context.WithValue(ctx, SEP24ClaimsContextKey, wantClaims) + + gotClaims = GetSEP24Claims(ctx) + require.Equal(t, wantClaims, gotClaims) +} + +func Test_SEP24UnauthenticatedRoutes(t *testing.T) { + r := chi.NewRouter() + + r.Get("/unauthenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + t.Run("doesn't return Unauthorized for unauthenticated routes", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/unauthenticated", nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.JSONEq(t, `{"status":"ok"}`, string(respBody)) + }) +} + +func Test_SEP24QueryTokenAuthenticateMiddleware(t *testing.T) { + tokenSecret := "jwt_secret_1234567890" + r := chi.NewRouter() + jwtManager, err := NewJWTManager(tokenSecret, 5000) + require.NoError(t, err) + + r.Group(func(r chi.Router) { + r.Use(SEP24QueryTokenAuthenticateMiddleware(jwtManager, network.TestNetworkPassphrase)) + + r.Get("/authenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + }) + + t.Run("returns Unauthorized for authenticated routes without token", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "no token was provided in the request") + }) + + t.Run("returns Unauthorized if the jwt could not be parsed", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err := http.NewRequest(http.MethodGet, "/authenticated?token=123", nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: token contains an invalid number of segments") + }) + + t.Run("returns Unauthorized if the jwt is expired", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + expiredToken := "eyJjbGllbnRfZG9tYWluIjoidGVzdC5jb20iLCJzdWIiOiJHQkxUWEY0NkpUQ0dNV0ZKQVNRTFZYTU1BMzZJUFlURENONEVONzNIUlhDR0RDR1lCWk0zQTQ0NCIsImV4cCI6MTY4MTQxMDkzMiwianRpIjoidGVzdC10cmFuc2FjdGlvbi1pZCJ9.RThqCuWkjBr1xw8LOBogDmw8RyMnrELDkA-w4Jv5x_E" + req, err := http.NewRequest(http.MethodGet, "/authenticated?token="+expiredToken, nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: token contains an invalid number of segments") + }) + + t.Run("returns Unauthorized if the token is valid but the transaction_id is not different from what's expected", func(t *testing.T) { + validToken, err := jwtManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "test.com", "test-transaction-id") + require.NoError(t, err) + + urlStr := fmt.Sprintf("/authenticated?transaction_id=%s&token=%s", "invalid-transaction-id", validToken) + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.JSONEq(t, `{"error":"The request was invalid in some way."}`, string(respBody)) + }) + + t.Run("returns Unauthorized if the jwt expiration is good but another parameter (stellar account) is weird", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + // create a token with an odd subject (stellar_account:memo) + badClaims := SEP24JWTClaims{ + ClientDomainClaim: "test.com", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "bad-subject", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Second)), + }, + } + tokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, badClaims) + badToken, err := tokenObj.SignedString([]byte(tokenSecret)) + require.NoError(t, err) + + urlStr := fmt.Sprintf("/authenticated?transaction_id=%s&token=%s", "test-transaction-id", badToken) + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: stellar account is invalid: non-canonical strkey; unused leftover character") + }) + + t.Run("returns Unauthorized if the jwt was signed with a different secret", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + // create a token with an odd subject (stellar_account:memo) + anotherTokenSecret := tokenSecret + "another" + anotherJWTManager, err := NewJWTManager(anotherTokenSecret, 5000) + require.NoError(t, err) + tokenWithDifferentSigner, err := anotherJWTManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "test.com", "valid-transaction-id") + require.NoError(t, err) + + urlStr := fmt.Sprintf("/authenticated?transaction_id=%s&token=%s", "valid-transaction-id", tokenWithDifferentSigner) + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: signature is invalid") + }) + + t.Run("both the token and the transaction_id are valid πŸŽ‰", func(t *testing.T) { + var contextClaims *SEP24JWTClaims + require.Nil(t, contextClaims) + r.With(SEP24QueryTokenAuthenticateMiddleware(jwtManager, network.TestNetworkPassphrase)).Get("/authenticated_success", func(w http.ResponseWriter, r *http.Request) { + contextClaims = r.Context().Value(SEP24ClaimsContextKey).(*SEP24JWTClaims) + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + now := time.Now() + validTransactionID := "valid-transaction-id" + validToken, err := jwtManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "test.com", validTransactionID) + require.NoError(t, err) + + urlStr := fmt.Sprintf("/authenticated_success?transaction_id=%s&token=%s", validTransactionID, validToken) + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.JSONEq(t, `{"status":"ok"}`, string(respBody)) + + // validate the context claims + require.NotNil(t, contextClaims) + require.Equal(t, "test.com", contextClaims.ClientDomain()) + require.Equal(t, "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", contextClaims.SEP10StellarAccount()) + require.Equal(t, validTransactionID, contextClaims.TransactionID()) + require.Empty(t, contextClaims.SEP10StellarMemo()) + require.True(t, contextClaims.ExpiresAt().After(now.Add(time.Duration(4000*time.Millisecond)))) + require.True(t, contextClaims.ExpiresAt().Before(now.Add(time.Duration(5000*time.Millisecond)))) + }) + + t.Run("token with empty client domain but valid in testnet πŸŽ‰", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.WarnLevel) + + var contextClaims *SEP24JWTClaims + require.Nil(t, contextClaims) + r.With(SEP24QueryTokenAuthenticateMiddleware(jwtManager, network.TestNetworkPassphrase)).Get("/authenticated_testnet", func(w http.ResponseWriter, r *http.Request) { + contextClaims = r.Context().Value(SEP24ClaimsContextKey).(*SEP24JWTClaims) + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + validTransactionID := "valid-transaction-id" + validToken, err := jwtManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "", validTransactionID) + require.NoError(t, err) + + urlStr := fmt.Sprintf("/authenticated_testnet?transaction_id=%s&token=%s", validTransactionID, validToken) + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.JSONEq(t, `{"status":"ok"}`, string(respBody)) + + // check client domain + require.Empty(t, contextClaims.ClientDomain()) + + // validate logs + require.Contains(t, buf.String(), "missing client domain in the token claims") + }) + + t.Run("token with empty client domain returns error in pubnet πŸŽ‰", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + r.With(SEP24QueryTokenAuthenticateMiddleware(jwtManager, network.PublicNetworkPassphrase)).Get("/authenticated_pubnet", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + validTransactionID := "valid-transaction-id" + validToken, err := jwtManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "", validTransactionID) + require.NoError(t, err) + + urlStr := fmt.Sprintf("/authenticated_pubnet?transaction_id=%s&token=%s", validTransactionID, validToken) + req, err := http.NewRequest(http.MethodGet, urlStr, nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.JSONEq(t, `{"error":"The request was invalid in some way."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "missing client domain in the token claims") + }) +} + +func Test_SEP24HeaderTokenAuthenticateMiddleware(t *testing.T) { + tokenSecret := "jwt_secret_1234567890" + r := chi.NewRouter() + jwtManager, err := NewJWTManager(tokenSecret, 5000) + require.NoError(t, err) + + r.Group(func(r chi.Router) { + r.Use(SEP24HeaderTokenAuthenticateMiddleware(jwtManager, network.TestNetworkPassphrase)) + + r.Get("/authenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + }) + + t.Run("returns Unauthorized for authenticated routes without token", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "no token was provided in the Authorization header") + }) + + t.Run("returns Unauthorized if the authorization header is invalid", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "InvalidToken") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Invalid Authorization header provided."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "invalid Authorization header provided") + }) + + t.Run("returns Unauthorized if the jwt could not be parsed", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer 123") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: token contains an invalid number of segments") + }) + + t.Run("returns Unauthorized if the jwt is expired", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + expiredToken := "eyJjbGllbnRfZG9tYWluIjoidGVzdC5jb20iLCJzdWIiOiJHQkxUWEY0NkpUQ0dNV0ZKQVNRTFZYTU1BMzZJUFlURENONEVONzNIUlhDR0RDR1lCWk0zQTQ0NCIsImV4cCI6MTY4MTQxMDkzMiwianRpIjoidGVzdC10cmFuc2FjdGlvbi1pZCJ9.RThqCuWkjBr1xw8LOBogDmw8RyMnrELDkA-w4Jv5x_E" + authHeader := "Bearer " + expiredToken + req.Header.Set("Authorization", authHeader) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: token contains an invalid number of segments") + }) + + t.Run("returns Unauthorized if the jwt expiration is good but another parameter (stellar account) is weird", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + // create a token with an odd subject (stellar_account:memo) + badClaims := SEP24JWTClaims{ + ClientDomainClaim: "test.com", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "bad-subject", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Second)), + }, + } + tokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, badClaims) + badToken, err := tokenObj.SignedString([]byte(tokenSecret)) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + authHeader := "Bearer " + badToken + req.Header.Set("Authorization", authHeader) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: stellar account is invalid: non-canonical strkey; unused leftover character") + }) + + t.Run("returns Unauthorized if the jwt was signed with a different secret", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + // create a token signed with a different secret + anotherTokenSecret := tokenSecret + "another" + anotherJWTManager, err := NewJWTManager(anotherTokenSecret, 5000) + require.NoError(t, err) + tokenWithDifferentSigner, err := anotherJWTManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "test.com", "valid-transaction-id") + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + authHeader := "Bearer " + tokenWithDifferentSigner + req.Header.Set("Authorization", authHeader) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, resp.StatusCode) + require.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "parsing the token claims: parsing SEP24 token: signature is invalid") + }) + + t.Run("token is valid πŸŽ‰", func(t *testing.T) { + var contextClaims *SEP24JWTClaims + require.Nil(t, contextClaims) + r.With(SEP24HeaderTokenAuthenticateMiddleware(jwtManager, network.TestNetworkPassphrase)).Get("/authenticated_success", func(w http.ResponseWriter, r *http.Request) { + contextClaims = r.Context().Value(SEP24ClaimsContextKey).(*SEP24JWTClaims) + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + now := time.Now() + validTransactionID := "valid-transaction-id" + validToken, err := jwtManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "test.com", validTransactionID) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/authenticated_success", nil) + require.NoError(t, err) + authHeader := "Bearer " + validToken + req.Header.Set("Authorization", authHeader) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.JSONEq(t, `{"status":"ok"}`, string(respBody)) + + // validate the context claims + require.NotNil(t, contextClaims) + require.Equal(t, "test.com", contextClaims.ClientDomain()) + require.Equal(t, "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", contextClaims.SEP10StellarAccount()) + require.Equal(t, validTransactionID, contextClaims.TransactionID()) + require.Empty(t, contextClaims.SEP10StellarMemo()) + require.True(t, contextClaims.ExpiresAt().After(now.Add(time.Duration(4000*time.Millisecond)))) + require.True(t, contextClaims.ExpiresAt().Before(now.Add(time.Duration(5000*time.Millisecond)))) + }) + + t.Run("token with empty client domain is valid in testnet πŸŽ‰", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.WarnLevel) + + var contextClaims *SEP24JWTClaims + require.Nil(t, contextClaims) + r.With(SEP24HeaderTokenAuthenticateMiddleware(jwtManager, network.TestNetworkPassphrase)).Get("/authenticated_testnet", func(w http.ResponseWriter, r *http.Request) { + contextClaims = r.Context().Value(SEP24ClaimsContextKey).(*SEP24JWTClaims) + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + validTransactionID := "valid-transaction-id" + validToken, err := jwtManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "", validTransactionID) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/authenticated_testnet", nil) + require.NoError(t, err) + authHeader := "Bearer " + validToken + req.Header.Set("Authorization", authHeader) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.JSONEq(t, `{"status":"ok"}`, string(respBody)) + + // check client domain + require.Empty(t, contextClaims.ClientDomain()) + + // validate logs + require.Contains(t, buf.String(), "missing client domain in the token claims") + }) + + t.Run("token with empty client domain returns error in pubnet πŸŽ‰", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + r.With(SEP24HeaderTokenAuthenticateMiddleware(jwtManager, network.PublicNetworkPassphrase)).Get("/authenticated_testnet", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + validTransactionID := "valid-transaction-id" + validToken, err := jwtManager.GenerateSEP24Token("GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", "", "", validTransactionID) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/authenticated_testnet", nil) + require.NoError(t, err) + authHeader := "Bearer " + validToken + req.Header.Set("Authorization", authHeader) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.JSONEq(t, `{"error":"The request was invalid in some way."}`, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "missing client domain in the token claims") + }) +} diff --git a/internal/anchorplatform/sep24_jwt_claims.go b/internal/anchorplatform/sep24_jwt_claims.go new file mode 100644 index 000000000..7e8d9a65d --- /dev/null +++ b/internal/anchorplatform/sep24_jwt_claims.go @@ -0,0 +1,80 @@ +package anchorplatform + +import ( + "fmt" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stellar/go/keypair" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type SEP24JWTClaims struct { + // Fields expected according with https://github.com/stellar/java-stellar-anchor-sdk/blob/bfa9b1d735f099bc6a21f0b9c55bd381a50c16b8/platform/src/main/java/org/stellar/anchor/platform/service/SimpleInteractiveUrlConstructor.java#L47-L56 + ClientDomainClaim string `json:"client_domain"` + jwt.RegisteredClaims +} + +func (c *SEP24JWTClaims) TransactionID() string { + return c.ID +} + +func (c *SEP24JWTClaims) SEP10StellarAccount() string { + // The SEP-10 account will be in the format "account:memo", in case there's a memo. + // That's why we'll split the string on ":" and get the first element. + // ref: https://github.com/stellar/java-stellar-anchor-sdk/blob/bfa9b1d735f099bc6a21f0b9c55bd381a50c16b8/platform/src/main/java/org/stellar/anchor/platform/service/SimpleInteractiveUrlConstructor.java#L47-L50 + splits := strings.Split(c.Subject, ":") + return splits[0] +} + +func (c *SEP24JWTClaims) SEP10StellarMemo() string { + // The SEP-10 account will be in the format "account:memo", in case there's a memo. + // That's why we'll split the string on ":" and get the second element. + // ref: https://github.com/stellar/java-stellar-anchor-sdk/blob/bfa9b1d735f099bc6a21f0b9c55bd381a50c16b8/platform/src/main/java/org/stellar/anchor/platform/service/SimpleInteractiveUrlConstructor.java#L47-L50 + splits := strings.Split(c.Subject, ":") + if len(splits) > 1 { + return splits[1] + } + return "" +} + +func (c *SEP24JWTClaims) ExpiresAt() *time.Time { + if c.RegisteredClaims.ExpiresAt == nil { + return nil + } + return &c.RegisteredClaims.ExpiresAt.Time +} + +func (c *SEP24JWTClaims) ClientDomain() string { + return c.ClientDomainClaim +} + +func (c SEP24JWTClaims) Valid() error { + if c.ExpiresAt() == nil { + return fmt.Errorf("expires_at is required") + } + + err := c.RegisteredClaims.Valid() + if err != nil { + return fmt.Errorf("validating registered claims: %w", err) + } + + if c.TransactionID() == "" { + return fmt.Errorf("transaction_id is required") + } + + _, err = keypair.ParseAddress(c.SEP10StellarAccount()) + if err != nil { + return fmt.Errorf("stellar account is invalid: %w", err) + } + + if c.ClientDomain() != "" { + err = utils.ValidateDNS(c.ClientDomain()) + if err != nil { + return fmt.Errorf("client_domain is invalid: %w", err) + } + } + + return nil +} diff --git a/internal/anchorplatform/sep24_jwt_claims_test.go b/internal/anchorplatform/sep24_jwt_claims_test.go new file mode 100644 index 000000000..f6b57c470 --- /dev/null +++ b/internal/anchorplatform/sep24_jwt_claims_test.go @@ -0,0 +1,68 @@ +package anchorplatform + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" +) + +func Test_SEP24JWTClaims_getters(t *testing.T) { + expiresAt := jwt.NewNumericDate(time.Now().Add(time.Minute * 5)) + claims := SEP24JWTClaims{ + ClientDomainClaim: "test.com", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GB54GWWWOSHATX5ALKHBBL2IQBZ2E7TBFO7F7VXKPIW6XANYDK4Y3RRC:123456", + ExpiresAt: expiresAt, + }, + } + + require.Equal(t, "test-transaction-id", claims.TransactionID()) + require.Equal(t, "GB54GWWWOSHATX5ALKHBBL2IQBZ2E7TBFO7F7VXKPIW6XANYDK4Y3RRC", claims.SEP10StellarAccount()) + require.Equal(t, "123456", claims.SEP10StellarMemo()) + require.Equal(t, "test.com", claims.ClientDomain()) + require.Equal(t, expiresAt.Time, *claims.ExpiresAt()) +} + +func Test_SEP24JWTClaims_valid(t *testing.T) { + // empty claims + claims := SEP24JWTClaims{} + err := claims.Valid() + require.EqualError(t, err, "expires_at is required") + + // expired claims + now := time.Now() + fiveMinAgo := now.Add(time.Minute * -5) + claims.RegisteredClaims.ExpiresAt = jwt.NewNumericDate(fiveMinAgo) + err = claims.Valid() + require.Contains(t, err.Error(), "validating registered claims: token is expired by 5m0") + + // missing transaction ID + fiveMinFromNow := now.Add(time.Minute * 5) + claims.RegisteredClaims.ExpiresAt = jwt.NewNumericDate(fiveMinFromNow) + err = claims.Valid() + require.EqualError(t, err, "transaction_id is required") + + // missing subject + claims.ID = "test-transaction-id" + err = claims.Valid() + require.EqualError(t, err, "stellar account is invalid: strkey is 0 bytes long; minimum valid length is 5") + + // invalid subject + claims.Subject = "invalid" + err = claims.Valid() + require.EqualError(t, err, "stellar account is invalid: base32 decode failed: illegal base32 data at input byte 7") + + // invalid client domain + claims.Subject = "GB54GWWWOSHATX5ALKHBBL2IQBZ2E7TBFO7F7VXKPIW6XANYDK4Y3RRC:123456" + claims.ClientDomainClaim = "localhost:8000" + err = claims.Valid() + require.EqualError(t, err, `client_domain is invalid: "localhost:8000" is not a valid DNS name`) + + // valid claims πŸŽ‰ + claims.ClientDomainClaim = "test.com" + err = claims.Valid() + require.NoError(t, err) +} diff --git a/internal/crashtracker/crash_tracker_client.go b/internal/crashtracker/crash_tracker_client.go new file mode 100644 index 000000000..803914f2e --- /dev/null +++ b/internal/crashtracker/crash_tracker_client.go @@ -0,0 +1,14 @@ +package crashtracker + +import ( + "context" + "time" +) + +type CrashTrackerClient interface { + LogAndReportErrors(ctx context.Context, err error, msg string) + LogAndReportMessages(ctx context.Context, msg string) + FlushEvents(waitTime time.Duration) bool + Recover() + Clone() CrashTrackerClient +} diff --git a/internal/crashtracker/dry_run_client.go b/internal/crashtracker/dry_run_client.go new file mode 100644 index 000000000..edc75f63c --- /dev/null +++ b/internal/crashtracker/dry_run_client.go @@ -0,0 +1,39 @@ +package crashtracker + +import ( + "context" + "fmt" + "time" + + "github.com/stellar/go/support/log" +) + +type dryRunClient struct{} + +func (s *dryRunClient) LogAndReportErrors(ctx context.Context, err error, msg string) { + if msg != "" { + err = fmt.Errorf("%s: %w", msg, err) + } + log.Ctx(ctx).Errorf("%+v", err) +} + +func (s *dryRunClient) LogAndReportMessages(ctx context.Context, msg string) { + log.Ctx(ctx).Info(msg) +} + +func (s *dryRunClient) FlushEvents(waitTime time.Duration) bool { + return false +} + +func (s *dryRunClient) Recover() {} + +func (s *dryRunClient) Clone() CrashTrackerClient { + return &dryRunClient{} +} + +func NewDryRunClient() (*dryRunClient, error) { + return &dryRunClient{}, nil +} + +// Ensuring that dryRunClient is implementing CrashTrackerClient interface +var _ CrashTrackerClient = (*dryRunClient)(nil) diff --git a/internal/crashtracker/dry_run_client_test.go b/internal/crashtracker/dry_run_client_test.go new file mode 100644 index 000000000..ea1fb7538 --- /dev/null +++ b/internal/crashtracker/dry_run_client_test.go @@ -0,0 +1,82 @@ +package crashtracker + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/stellar/go/support/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_DryRun_LogAndReportErrors(t *testing.T) { + mDryRunClient := &dryRunClient{} + mMsgError := "error" + mError := fmt.Errorf("mock error") + ctx := context.Background() + + t.Run("LogAndReportErrors without message", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + mDryRunClient.LogAndReportErrors(ctx, mError, mMsgError) + + // validate logs + require.Contains(t, buf.String(), "error: mock error") + }) + + t.Run("LogAndReportErrors with message", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + mDryRunClient.LogAndReportErrors(ctx, mError, mMsgError) + + // validate logs + require.Contains(t, buf.String(), "mock error") + }) +} + +func Test_DryRun_LogAndReportMessages(t *testing.T) { + mDryRunClient := &dryRunClient{} + mMsg := "mock message" + + t.Run("LogAndReportMessages without message", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.InfoLevel) + + mDryRunClient.LogAndReportMessages(context.Background(), mMsg) + + // validate logs + require.Contains(t, buf.String(), "mock message") + }) +} + +func Test_DryRun_FlushEvents(t *testing.T) { + mDryRunClient := &dryRunClient{} + + waitTimeout := time.Second + valid := mDryRunClient.FlushEvents(waitTimeout) + + assert.Equal(t, false, valid) +} + +func Test_DryRun_Clone(t *testing.T) { + mDryRunClient := &dryRunClient{} + + waitTimeout := time.Second + valid := mDryRunClient.FlushEvents(waitTimeout) + + assert.Equal(t, false, valid) + + cloneClient := mDryRunClient.Clone() + + assert.IsType(t, &dryRunClient{}, cloneClient) + assert.NotEqual(t, mDryRunClient, &cloneClient) +} diff --git a/internal/crashtracker/main.go b/internal/crashtracker/main.go new file mode 100644 index 000000000..c4dfd87a6 --- /dev/null +++ b/internal/crashtracker/main.go @@ -0,0 +1,53 @@ +package crashtracker + +import ( + "context" + "fmt" + "strings" + + "github.com/stellar/go/support/log" +) + +type CrashTrackerType string + +const ( + // CrashTrackerTypeSentry is used to monitor errors with sentry. + CrashTrackerTypeSentry CrashTrackerType = "SENTRY" + // CrashTrackerTypeDryRun is used for development environment + CrashTrackerTypeDryRun CrashTrackerType = "DRY_RUN" +) + +func ParseCrashTrackerType(messengerTypeStr string) (CrashTrackerType, error) { + crashTrackerTypeStrUpper := strings.ToUpper(messengerTypeStr) + ctType := CrashTrackerType(crashTrackerTypeStrUpper) + + switch ctType { + case CrashTrackerTypeSentry, CrashTrackerTypeDryRun: + return ctType, nil + default: + return "", fmt.Errorf("invalid crash tracker type %q", crashTrackerTypeStrUpper) + } +} + +type CrashTrackerOptions struct { + CrashTrackerType CrashTrackerType + Environment string + GitCommit string + + // Sentry variables + SentryDSN string +} + +func GetClient(ctx context.Context, opts CrashTrackerOptions) (CrashTrackerClient, error) { + switch opts.CrashTrackerType { + case CrashTrackerTypeSentry: + log.Ctx(ctx).Infof("Using %q crash tracker", opts.CrashTrackerType) + return NewSentryClient(opts.SentryDSN, opts.Environment, opts.GitCommit) + case CrashTrackerTypeDryRun: + log.Ctx(ctx).Warnf("Using %q crash tracker", opts.CrashTrackerType) + return NewDryRunClient() + + default: + return nil, fmt.Errorf("unknown crash tracker type: %q", opts.CrashTrackerType) + } +} diff --git a/internal/crashtracker/main_test.go b/internal/crashtracker/main_test.go new file mode 100644 index 000000000..80cddab35 --- /dev/null +++ b/internal/crashtracker/main_test.go @@ -0,0 +1,60 @@ +package crashtracker + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ParseCrashTrackerType(t *testing.T) { + testCases := []struct { + metricTypeStr string + expectedCrashTrackerType CrashTrackerType + wantErr error + }{ + {wantErr: fmt.Errorf("invalid crash tracker type \"\"")}, + {metricTypeStr: "MOCKCRASHTRACKERTYPE", wantErr: fmt.Errorf("invalid crash tracker type \"MOCKCRASHTRACKERTYPE\"")}, + {metricTypeStr: "sentry", expectedCrashTrackerType: CrashTrackerTypeSentry}, + {metricTypeStr: "SENtry", expectedCrashTrackerType: CrashTrackerTypeSentry}, + {metricTypeStr: "DRY_run", expectedCrashTrackerType: CrashTrackerTypeDryRun}, + {metricTypeStr: "dry_run", expectedCrashTrackerType: CrashTrackerTypeDryRun}, + } + for _, tc := range testCases { + t.Run("crashTrackerType: "+tc.metricTypeStr, func(t *testing.T) { + crashTrackerType, err := ParseCrashTrackerType(tc.metricTypeStr) + assert.Equal(t, tc.expectedCrashTrackerType, crashTrackerType) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func Test_GetClient(t *testing.T) { + ctx := context.Background() + crashTrackerOptions := CrashTrackerOptions{} + + t.Run("get sentry crash tracker client", func(t *testing.T) { + crashTrackerOptions.CrashTrackerType = CrashTrackerTypeSentry + + gotClient, err := GetClient(ctx, crashTrackerOptions) + assert.NoError(t, err) + assert.IsType(t, &sentryClient{}, gotClient) + }) + + t.Run("get dry run crash tracker client", func(t *testing.T) { + crashTrackerOptions.CrashTrackerType = CrashTrackerTypeDryRun + + gotClient, err := GetClient(ctx, crashTrackerOptions) + assert.NoError(t, err) + assert.IsType(t, &dryRunClient{}, gotClient) + }) + + t.Run("error metric passed is invalid", func(t *testing.T) { + crashTrackerOptions.CrashTrackerType = "MOCKCRASHTRACKERTYPE" + + gotClient, err := GetClient(ctx, crashTrackerOptions) + assert.Nil(t, gotClient) + assert.EqualError(t, err, "unknown crash tracker type: \"MOCKCRASHTRACKERTYPE\"") + }) +} diff --git a/internal/crashtracker/mocks.go b/internal/crashtracker/mocks.go new file mode 100644 index 000000000..1a0f32acf --- /dev/null +++ b/internal/crashtracker/mocks.go @@ -0,0 +1,35 @@ +package crashtracker + +import ( + "context" + "time" + + "github.com/stretchr/testify/mock" +) + +type MockCrashTrackerClient struct { + mock.Mock +} + +func (m *MockCrashTrackerClient) LogAndReportErrors(ctx context.Context, err error, msg string) { + m.Called(ctx, err, msg) +} + +func (m *MockCrashTrackerClient) LogAndReportMessages(ctx context.Context, msg string) { + m.Called(ctx, msg) +} + +func (m *MockCrashTrackerClient) FlushEvents(waitTime time.Duration) bool { + return m.Called(waitTime).Get(0).(bool) +} + +func (m *MockCrashTrackerClient) Recover() { + m.Called() +} + +func (m *MockCrashTrackerClient) Clone() CrashTrackerClient { + return m.Called().Get(0).(*MockCrashTrackerClient) +} + +// Ensuring that MockCrashTrackerClient is implementing CrashTrackerClient interface +var _ CrashTrackerClient = (*MockCrashTrackerClient)(nil) diff --git a/internal/crashtracker/sentry_client.go b/internal/crashtracker/sentry_client.go new file mode 100644 index 000000000..0faa3ce45 --- /dev/null +++ b/internal/crashtracker/sentry_client.go @@ -0,0 +1,109 @@ +package crashtracker + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/getsentry/sentry-go" + "github.com/stellar/go/support/log" +) + +type hubSentryInterface interface { + CaptureException(exception error) *sentry.EventID + CaptureMessage(message string) *sentry.EventID + Clone() *sentry.Hub + Flush(timeout time.Duration) bool + Recover(err interface{}) *sentry.EventID +} + +// Ensuring that *sentry.Hub is implementing hubSentryInterface interface. +var _ hubSentryInterface = (*sentry.Hub)(nil) + +type sentryInterface interface { + Init(options sentry.ClientOptions) error + GetHubFromContext(ctx context.Context) hubSentryInterface + CurrentHub() hubSentryInterface +} + +// sentryImplementation implements the sentry interface methods using the sentry module. +type sentryImplementation struct{} + +func (s *sentryImplementation) Init(options sentry.ClientOptions) error { + return sentry.Init(options) +} + +func (s *sentryImplementation) GetHubFromContext(ctx context.Context) hubSentryInterface { + return sentry.GetHubFromContext(ctx) +} + +func (s *sentryImplementation) CurrentHub() hubSentryInterface { + return sentry.CurrentHub() +} + +// Ensuring that *sentryImplementation is implementing sentryInterface interface. +var _ sentryInterface = (*sentryImplementation)(nil) + +type sentryClient struct { + hub hubSentryInterface + sentryImplementation sentryInterface +} + +// LogAndReportErrors is a method responsible to receive a err and a message and log this info before capture the exception with sentry. +func (s *sentryClient) LogAndReportErrors(ctx context.Context, err error, msg string) { + // check if error is context canceled: + if errors.Is(err, context.Canceled) { + log.Warn("context canceled, not reporting error to sentry") + return + } + + if msg != "" { + err = fmt.Errorf("%s: %w", msg, err) + } + log.Ctx(ctx).WithStack(err).Errorf("%+v", err) + s.hub.CaptureException(err) +} + +// LogAndReportMessages is a method responsible to receive a message and log this info before capture a message with sentry. +func (s *sentryClient) LogAndReportMessages(ctx context.Context, msg string) { + log.Ctx(ctx).Info(msg) + s.hub.CaptureMessage(msg) +} + +// FlushEvents is a method that implements a timeout for events to be dispatched after an application terminates. +func (s *sentryClient) FlushEvents(waitTime time.Duration) bool { + return s.hub.Flush(waitTime) +} + +// Recover is a method that capture unhandled panics. +func (s *sentryClient) Recover() { + if err := recover(); err != nil { + s.hub.Recover(err) + } +} + +// Clone is a method that clones a new CrashTrackerClient to be used in concurrent routines. +func (s *sentryClient) Clone() CrashTrackerClient { + cloneHub := s.hub.Clone() + return &sentryClient{hub: cloneHub} +} + +// NewSentryClient is a func that creates a new sentryClient using the sentryImplementation. +func NewSentryClient(sentryDSN string, environment string, gitCommit string) (*sentryClient, error) { + si := &sentryImplementation{} + err := si.Init(sentry.ClientOptions{ + Dsn: sentryDSN, + Release: gitCommit, + Environment: environment, + }) + if err != nil { + return nil, fmt.Errorf("error setting up Sentry: %w", err) + } + + hub := si.CurrentHub() + return &sentryClient{hub: hub, sentryImplementation: si}, nil +} + +// Ensuring that sentryClient is implementing CrashTrackerClient interface +var _ CrashTrackerClient = (*sentryClient)(nil) diff --git a/internal/crashtracker/sentry_client_test.go b/internal/crashtracker/sentry_client_test.go new file mode 100644 index 000000000..8bb1c3cee --- /dev/null +++ b/internal/crashtracker/sentry_client_test.go @@ -0,0 +1,172 @@ +package crashtracker + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/getsentry/sentry-go" + "github.com/stellar/go/support/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockHubSentry struct { + mock.Mock +} + +func (m *mockHubSentry) CaptureException(exception error) *sentry.EventID { + return m.Called(exception).Get(0).(*sentry.EventID) +} + +func (m *mockHubSentry) CaptureMessage(message string) *sentry.EventID { + return m.Called(message).Get(0).(*sentry.EventID) +} + +func (m *mockHubSentry) Clone() *sentry.Hub { + return m.Called().Get(0).(*sentry.Hub) +} + +func (m *mockHubSentry) Flush(timeout time.Duration) bool { + return m.Called(timeout).Get(0).(bool) +} + +func (m *mockHubSentry) Recover(err interface{}) *sentry.EventID { + return m.Called(err).Get(0).(*sentry.EventID) +} + +// Ensuring that mockSentry is implementing sentryInterface interface +var _ hubSentryInterface = (*mockHubSentry)(nil) + +type mockSentry struct { + mock.Mock +} + +func (m *mockSentry) Init(options sentry.ClientOptions) error { + return m.Called(options).Error(0) +} + +func (m *mockSentry) GetHubFromContext(ctx context.Context) hubSentryInterface { + return m.Called(ctx).Get(0).(*mockHubSentry) +} + +func (m *mockSentry) CurrentHub() hubSentryInterface { + return m.Called().Get(0).(*mockHubSentry) +} + +// Ensuring that *mockSentry is implementing sentryInterface interface. +var _ sentryInterface = (*mockSentry)(nil) + +func Test_SentryClient_LogAndReportErrors(t *testing.T) { + mHubSentry := &mockHubSentry{} + + mSentryClient := &sentryClient{ + hub: mHubSentry, + } + mMsgError := "error" + mError := fmt.Errorf("mock error") + ctx := context.Background() + + t.Run("LogAndReportErrors without message", func(t *testing.T) { + e := fmt.Errorf("%s: %w", mMsgError, mError) + sentryId := sentry.EventID("id-1") + + mHubSentry.On("CaptureException", e).Return(&sentryId).Once() + mSentryClient.LogAndReportErrors(ctx, mError, mMsgError) + + mHubSentry.AssertExpectations(t) + }) + + t.Run("LogAndReportErrors with message", func(t *testing.T) { + mMsgError = "" + sentryId := sentry.EventID("id-1") + + mHubSentry.On("CaptureException", mError).Return(&sentryId).Once() + mSentryClient.LogAndReportErrors(ctx, mError, mMsgError) + + mHubSentry.AssertExpectations(t) + }) + + t.Run("LogAndReportErrors ignores context.Canceled", func(t *testing.T) { + mHubSentry = &mockHubSentry{} + mSentryClient = &sentryClient{hub: mHubSentry} + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + err := fmt.Errorf("external error that wraps: %w", context.Canceled) + mSentryClient.LogAndReportErrors(ctx, err, mMsgError) + mHubSentry.AssertNotCalled(t, "CaptureException", mock.Anything) + + require.Contains(t, buf.String(), "context canceled, not reporting error to sentry") + }) +} + +func Test_SentryClient_LogAndReportMessages(t *testing.T) { + mHubSentry := &mockHubSentry{} + + mSentryClient := &sentryClient{ + hub: mHubSentry, + } + mMsgError := "crash error" + + sentryId := sentry.EventID("id-1") + + mHubSentry.On("CaptureMessage", mMsgError).Return(&sentryId).Once() + mSentryClient.LogAndReportMessages(context.Background(), mMsgError) + + mHubSentry.AssertExpectations(t) +} + +func Test_SentryClient_FlushEvents(t *testing.T) { + mHubSentry := &mockHubSentry{} + + mSentryClient := &sentryClient{ + hub: mHubSentry, + } + waitTimeout := time.Second + + mHubSentry.On("Flush", waitTimeout).Return(true).Once() + mSentryClient.FlushEvents(waitTimeout) + + mHubSentry.AssertExpectations(t) +} + +func Test_SentryClient_Recover(t *testing.T) { + mHubSentry := &mockHubSentry{} + + mSentryClient := &sentryClient{ + hub: mHubSentry, + } + + mockErr := fmt.Errorf("error test") + sentryId := sentry.EventID("id-1") + + mHubSentry.On("Recover", mockErr).Return(&sentryId).Once() + + defer mHubSentry.AssertExpectations(t) + defer mSentryClient.Recover() + + panic(mockErr) +} + +func Test_SentryClient_Clone(t *testing.T) { + mHubSentry := &mockHubSentry{} + + mSentryClient := &sentryClient{ + hub: mHubSentry, + } + + hub := sentry.Hub{} + mHubSentry.On("Clone").Return(&hub).Once() + + cloneClient := mSentryClient.Clone() + + sc := cloneClient.(*sentryClient) + assert.Equal(t, &hub, sc.hub) + + mHubSentry.AssertExpectations(t) +} diff --git a/internal/data/assets.go b/internal/data/assets.go new file mode 100644 index 000000000..1c93793aa --- /dev/null +++ b/internal/data/assets.go @@ -0,0 +1,235 @@ +package data + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type Asset struct { + ID string `json:"id" db:"id"` + Code string `json:"code" db:"code"` + Issuer string `json:"issuer" db:"issuer"` + CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"` + UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"` + DeletedAt *time.Time `json:"deleted_at" db:"deleted_at"` +} + +type AssetModel struct { + dbConnectionPool db.DBConnectionPool +} + +func (a *AssetModel) Get(ctx context.Context, id string) (*Asset, error) { + var asset Asset + query := ` + SELECT + a.id, + a.code, + a.issuer, + a.created_at, + a.updated_at, + a.deleted_at + FROM + assets a + WHERE + a.id = $1 + ` + + err := a.dbConnectionPool.GetContext(ctx, &asset, query, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying asset ID %s: %w", id, err) + } + return &asset, nil +} + +// GetByCodeAndIssuer returns asset filtering by code and issuer. +func (a *AssetModel) GetByCodeAndIssuer(ctx context.Context, code, issuer string) (*Asset, error) { + var asset Asset + query := ` + SELECT + a.id, + a.code, + a.issuer, + a.created_at, + a.updated_at, + a.deleted_at + FROM + assets a + WHERE a.code = $1 + AND a.issuer = $2 + ` + + err := a.dbConnectionPool.GetContext(ctx, &asset, query, code, issuer) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying asset with code %s and issuer %s: %w", code, issuer, err) + } + return &asset, nil +} + +// GetAll returns all assets in the database. +func (a *AssetModel) GetAll(ctx context.Context) ([]Asset, error) { + // TODO: We will want to filter out "deleted" assets at some point + assets := []Asset{} + query := ` + SELECT + a.id, + a.code, + a.issuer, + a.created_at, + a.updated_at, + a.deleted_at + FROM + assets a + ORDER BY + a.code ASC + ` + + err := a.dbConnectionPool.SelectContext(ctx, &assets, query) + if err != nil { + return nil, fmt.Errorf("error querying assets: %w", err) + } + return assets, nil +} + +func (a *AssetModel) Insert(ctx context.Context, sqlExec db.SQLExecuter, code string, issuer string) (*Asset, error) { + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ON CONFLICT (code, issuer) DO + UPDATE SET + deleted_at = NULL + WHERE + assets.deleted_at IS NOT NULL + RETURNING * + ` + + var asset Asset + err := sqlExec.GetContext(ctx, &asset, query, code, issuer) + if err != nil { + return nil, fmt.Errorf("error inserting asset: %w", err) + } + + return &asset, nil +} + +func (a *AssetModel) GetOrCreate(ctx context.Context, code, issuer string) (*Asset, error) { + const query = ` + WITH create_asset AS( + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ON CONFLICT (code, issuer) DO NOTHING + RETURNING * + ) + SELECT * FROM create_asset ca + UNION ALL + SELECT * FROM assets a + WHERE a.code = $1 + AND a.issuer = $2 + ` + + var asset Asset + err := a.dbConnectionPool.GetContext(ctx, &asset, query, code, issuer) + if err != nil { + return nil, fmt.Errorf("error getting or creating asset: %w", err) + } + + return &asset, nil +} + +func (a *AssetModel) SoftDelete(ctx context.Context, sqlExec db.SQLExecuter, id string) (*Asset, error) { + query := ` + UPDATE + assets + SET + deleted_at = NOW() + WHERE id = $1 + RETURNING * + ` + + var asset Asset + err := sqlExec.GetContext(ctx, &asset, query, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error soft deleting asset ID %s: %w", id, err) + } + return &asset, nil +} + +type ReceiverWalletAsset struct { + WalletID string `db:"wallet_id"` + ReceiverWallet ReceiverWallet `db:"receiver_wallet"` + Asset Asset `db:"asset"` +} + +// GetAssetsPerReceiverWallet returns the assets associated with a READY payment for each receiver +// wallet provided. +func (a *AssetModel) GetAssetsPerReceiverWallet(ctx context.Context, receiverWallets ...*ReceiverWallet) ([]ReceiverWalletAsset, error) { + receiverWalletIDs := make([]string, len(receiverWallets)) + for i, rw := range receiverWallets { + receiverWalletIDs[i] = rw.ID + } + + var receiverWalletsAssets []ReceiverWalletAsset + query := ` + WITH latest_payments_by_wallet AS ( + -- Gets the latest payment by wallet with its asset + SELECT + p.id AS payment_id, + d.wallet_id, + p.asset_id + FROM + payments p + INNER JOIN disbursements d ON d.id = p.disbursement_id + INNER JOIN assets a ON a.id = p.asset_id + WHERE + p.status = $1 + GROUP BY + p.id, p.asset_id, d.wallet_id + ORDER BY + p.updated_at DESC + ) + SELECT DISTINCT + lpw.wallet_id, + rw.id AS "receiver_wallet.id", + r.id AS "receiver_wallet.receiver.id", + r.phone_number AS "receiver_wallet.receiver.phone_number", + r.email AS "receiver_wallet.receiver.email", + a.id AS "asset.id", + a.code AS "asset.code", + a.issuer AS "asset.issuer", + a.created_at AS "asset.created_at", + a.updated_at AS "asset.updated_at" + FROM + assets a + INNER JOIN latest_payments_by_wallet lpw ON lpw.asset_id = a.id + INNER JOIN payments p ON p.id = lpw.payment_id + INNER JOIN receiver_wallets rw ON rw.id = p.receiver_wallet_id + INNER JOIN receivers r ON r.id = rw.receiver_id + WHERE + rw.id = ANY($2) + ` + + err := a.dbConnectionPool.SelectContext(ctx, &receiverWalletsAssets, query, ReadyPaymentStatus, pq.Array(receiverWalletIDs)) + if err != nil { + return nil, fmt.Errorf("error querying most recent asset per receiver wallet: %w", err) + } + + return receiverWalletsAssets, nil +} diff --git a/internal/data/assets_test.go b/internal/data/assets_test.go new file mode 100644 index 000000000..558b40d1c --- /dev/null +++ b/internal/data/assets_test.go @@ -0,0 +1,497 @@ +package data + +import ( + "context" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_AssetModelGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + assetModel := &AssetModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when asset is not found", func(t *testing.T) { + _, err := assetModel.Get(ctx, "not-found") + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns asset successfully", func(t *testing.T) { + expected := CreateAssetFixture(t, ctx, dbConnectionPool.SqlxDB(), "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + actual, err := assetModel.Get(ctx, expected.ID) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) +} + +func Test_AssetModelGetByCodeAndIssuer(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + assetModel := &AssetModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when asset is not found", func(t *testing.T) { + _, err := assetModel.GetByCodeAndIssuer(ctx, "invalid_code", "invalid_issuer") + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns asset successfully", func(t *testing.T) { + expected := CreateAssetFixture(t, ctx, dbConnectionPool.SqlxDB(), "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + actual, err := assetModel.GetByCodeAndIssuer(ctx, expected.Code, expected.Issuer) + require.NoError(t, err) + assert.Equal(t, expected, actual) + }) +} + +func Test_AssetModelGetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + assetModel := &AssetModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns all assets successfully", func(t *testing.T) { + expected := ClearAndCreateAssetFixtures(t, ctx, dbConnectionPool.SqlxDB()) + actual, err := assetModel.GetAll(ctx) + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) + + t.Run("returns empty array when no assets", func(t *testing.T) { + DeleteAllAssetFixtures(t, ctx, dbConnectionPool.SqlxDB()) + actual, err := assetModel.GetAll(ctx) + require.NoError(t, err) + + assert.Equal(t, []Asset{}, actual) + }) +} + +func Test_AssetModelInsert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + assetModel := &AssetModel{dbConnectionPool: dbConnectionPool} + + t.Run("inserts asset successfully", func(t *testing.T) { + DeleteAllAssetFixtures(t, ctx, dbConnectionPool.SqlxDB()) + code := "USDT" + issuer := "GBVHJTRLQRMIHRYTXZQOPVYCVVH7IRJN3DOFT7VC6U75CBWWBVDTWURG" + + asset, err := assetModel.Insert(ctx, dbConnectionPool, code, issuer) + require.NoError(t, err) + assert.NotNil(t, asset) + + insertedAsset, err := assetModel.Get(ctx, asset.ID) + require.NoError(t, err) + assert.NotNil(t, insertedAsset) + }) + + t.Run("re-create a deleted asset", func(t *testing.T) { + DeleteAllAssetFixtures(t, ctx, dbConnectionPool.SqlxDB()) + code := "USDT" + issuer := "GBVHJTRLQRMIHRYTXZQOPVYCVVH7IRJN3DOFT7VC6U75CBWWBVDTWURG" + + usdt, err := assetModel.Insert(ctx, dbConnectionPool, code, issuer) + require.NoError(t, err) + assert.NotNil(t, usdt) + + usdc, err := assetModel.Insert(ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + require.NoError(t, err) + assert.NotNil(t, usdt) + + _, err = assetModel.SoftDelete(ctx, dbConnectionPool, usdc.ID) + require.NoError(t, err) + + _, err = assetModel.SoftDelete(ctx, dbConnectionPool, usdt.ID) + require.NoError(t, err) + + usdcDB, err := assetModel.Get(ctx, usdc.ID) + require.NoError(t, err) + assert.NotNil(t, usdcDB.DeletedAt) + + reCreatedUSDT, err := assetModel.Insert(ctx, dbConnectionPool, code, issuer) + require.NoError(t, err) + assert.NotNil(t, reCreatedUSDT) + + usdtDB, err := assetModel.Get(ctx, usdt.ID) + require.NoError(t, err) + + assert.NotNil(t, usdtDB) + + assert.Equal(t, usdtDB.ID, usdt.ID) + assert.Equal(t, usdtDB.Code, usdt.Code) + assert.Equal(t, usdtDB.Issuer, usdt.Issuer) + + assert.Equal(t, usdtDB.ID, reCreatedUSDT.ID) + assert.Equal(t, usdtDB.Code, reCreatedUSDT.Code) + assert.Equal(t, usdtDB.Issuer, reCreatedUSDT.Issuer) + + usdcDB, err = assetModel.Get(ctx, usdc.ID) + require.NoError(t, err) + assert.NotNil(t, usdcDB.DeletedAt) + }) + + t.Run("does not insert the same asset again", func(t *testing.T) { + DeleteAllAssetFixtures(t, ctx, dbConnectionPool.SqlxDB()) + code := "USDT" + issuer := "GBVHJTRLQRMIHRYTXZQOPVYCVVH7IRJN3DOFT7VC6U75CBWWBVDTWURG" + + asset, err := assetModel.Insert(ctx, dbConnectionPool, code, issuer) + require.NoError(t, err) + assert.NotNil(t, asset) + + duplicatedAsset, err := assetModel.Insert(ctx, dbConnectionPool, code, issuer) + assert.EqualError(t, err, "error inserting asset: sql: no rows in result set") + assert.Nil(t, duplicatedAsset) + }) +} + +func Test_AssetModelGetOrCreate(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + assetModel := &AssetModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when issuer is invalid", func(t *testing.T) { + asset, err := assetModel.GetOrCreate(ctx, "FOO1", "invalid_issuer") + require.EqualError(t, err, "error getting or creating asset: pq: new row for relation \"assets\" violates check constraint \"asset_issuer_length_check\"") + assert.Empty(t, asset) + }) + + t.Run("creates asset successfully", func(t *testing.T) { + asset, err := assetModel.GetOrCreate(ctx, "F001", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + require.NoError(t, err) + assert.Equal(t, "F001", asset.Code) + assert.Equal(t, "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", asset.Issuer) + }) + + t.Run("returns asset successfully", func(t *testing.T) { + expected := CreateAssetFixture(t, ctx, dbConnectionPool.SqlxDB(), "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + asset, err := assetModel.GetOrCreate(ctx, expected.Code, expected.Issuer) + require.NoError(t, err) + assert.Equal(t, expected.ID, asset.ID) + }) +} + +func Test_AssetModelSoftDelete(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + assetModel := &AssetModel{dbConnectionPool: dbConnectionPool} + + t.Run("delete successful", func(t *testing.T) { + DeleteAllAssetFixtures(t, ctx, dbConnectionPool.SqlxDB()) + expected := CreateAssetFixture(t, ctx, dbConnectionPool.SqlxDB(), "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + + asset, err := assetModel.SoftDelete(ctx, dbConnectionPool, expected.ID) + require.NoError(t, err) + assert.NotNil(t, asset) + assert.NotNil(t, asset.DeletedAt) + deletedAt := asset.DeletedAt + + deletedAsset, err := assetModel.Get(ctx, expected.ID) + require.NoError(t, err) + assert.NotNil(t, deletedAsset) + assert.Equal(t, deletedAsset.DeletedAt, deletedAt) + }) + + t.Run("delete unsuccessful, cannot find asset", func(t *testing.T) { + DeleteAllAssetFixtures(t, ctx, dbConnectionPool.SqlxDB()) + + _, err := assetModel.SoftDelete(ctx, dbConnectionPool, "non-existant") + require.Error(t, err) + }) +} + +func Test_GetAssetsPerReceiverWallet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + models, err := NewModels(dbConnectionPool) + require.NoError(t, err) + + // 1. Create assets, wallets and disbursements: + country := CreateCountryFixture(t, ctx, dbConnectionPool, "ATL", "Atlantis") + + asset1 := CreateAssetFixture(t, ctx, dbConnectionPool, "FOO1", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + asset2 := CreateAssetFixture(t, ctx, dbConnectionPool, "FOO2", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + + walletA := CreateWalletFixture(t, ctx, dbConnectionPool, "walletA", "https://www.a.com", "www.a.com", "a://") + walletB := CreateWalletFixture(t, ctx, dbConnectionPool, "walletB", "https://www.b.com", "www.b.com", "b://") + + disbursementA1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Country: country, + Wallet: walletA, + Status: ReadyDisbursementStatus, + Asset: asset1, + }) + disbursementA2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Country: country, + Wallet: walletA, + Status: ReadyDisbursementStatus, + Asset: asset2, + }) + disbursementB1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Country: country, + Wallet: walletB, + Status: ReadyDisbursementStatus, + Asset: asset1, + }) + disbursementB2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Country: country, + Wallet: walletB, + Status: ReadyDisbursementStatus, + Asset: asset2, + }) + + // 2. Create receivers, and receiver wallets: + receiverX := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverY := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + receiverWalletXA := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverX.ID, walletA.ID, DraftReceiversWalletStatus) + receiverWalletXB := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverX.ID, walletB.ID, DraftReceiversWalletStatus) + receiverWalletYA := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverY.ID, walletA.ID, DraftReceiversWalletStatus) + receiverWalletYB := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverY.ID, walletB.ID, DraftReceiversWalletStatus) + + // 3. Create payments: + // paymentXA1 - walletA, asset1 for receiverX on their receiverWalletA + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletXA, + Disbursement: disbursementA1, + Asset: *asset1, + Status: ReadyPaymentStatus, + Amount: "1", + }) + + // paymentXA2 - walletA, asset2 for receiverX on their receiverWalletA + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletXA, + Disbursement: disbursementA2, + Asset: *asset2, + Status: ReadyPaymentStatus, + Amount: "1", + }) + + // paymentXA2 - walletA, asset2 for receiverX on their receiverWalletA - This should be ignored + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletXA, + Disbursement: disbursementA2, + Asset: *asset2, + Status: ReadyPaymentStatus, + Amount: "1", + }) + + // paymentXB2 - walletB, asset2 for receiverX on their receiverWalletB + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletXB, + Disbursement: disbursementB2, + Asset: *asset2, + Status: ReadyPaymentStatus, + Amount: "1", + }) + + // paymentXB1 - walletB, asset1 for receiverX on their receiverWalletB + time.Sleep(10 * time.Millisecond) + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletXB, + Disbursement: disbursementB1, + Asset: *asset1, + Status: ReadyPaymentStatus, + Amount: "1", + }) + + // paymentYA2 - walletA, asset2 for receiverY on their receiverWalletA + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletYA, + Disbursement: disbursementA2, + Asset: *asset2, + Status: ReadyPaymentStatus, + UpdatedAt: time.Date(2024, 1, 6, 0, 0, 0, 0, time.UTC), + Amount: "1", + }) + + // paymentYA1 - walletA, asset1 for receiverY on their receiverWalletA + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletYA, + Disbursement: disbursementA1, + Asset: *asset1, + Status: ReadyPaymentStatus, + UpdatedAt: time.Date(2024, 2, 5, 0, 0, 0, 0, time.UTC), + Amount: "1", + }) + + // paymentYB1 - walletB, asset1 for receiverY on their receiverWalletB + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletYB, + Disbursement: disbursementB1, + Asset: *asset1, + Status: ReadyPaymentStatus, + UpdatedAt: time.Date(2024, 1, 7, 0, 0, 0, 0, time.UTC), + Amount: "1", + }) + + // paymentYB2 - walletB, asset2 for receiverY on their receiverWalletB + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletYB, + Disbursement: disbursementB2, + Asset: *asset2, + Status: ReadyPaymentStatus, + UpdatedAt: time.Date(2024, 1, 8, 0, 0, 0, 0, time.UTC), + Amount: "1", + }) + + gotLatestAssetsPerRW, err := models.Assets.GetAssetsPerReceiverWallet(ctx, receiverWalletXA, receiverWalletXB, receiverWalletYA, receiverWalletYB) + require.NoError(t, err) + require.Len(t, gotLatestAssetsPerRW, 8) + + wantLatestAssetsPerRW := []ReceiverWalletAsset{ + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletXA.ID, + Receiver: Receiver{ + ID: receiverX.ID, + Email: receiverX.Email, + PhoneNumber: receiverX.PhoneNumber, + }, + }, + WalletID: walletA.ID, + Asset: *asset1, + }, + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletXA.ID, + Receiver: Receiver{ + ID: receiverX.ID, + Email: receiverX.Email, + PhoneNumber: receiverX.PhoneNumber, + }, + }, + WalletID: walletA.ID, + Asset: *asset2, + }, + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletXB.ID, + Receiver: Receiver{ + ID: receiverX.ID, + Email: receiverX.Email, + PhoneNumber: receiverX.PhoneNumber, + }, + }, + WalletID: walletB.ID, + Asset: *asset1, + }, + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletXB.ID, + Receiver: Receiver{ + ID: receiverX.ID, + Email: receiverX.Email, + PhoneNumber: receiverX.PhoneNumber, + }, + }, + WalletID: walletB.ID, + Asset: *asset2, + }, + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletYA.ID, + Receiver: Receiver{ + ID: receiverY.ID, + Email: receiverY.Email, + PhoneNumber: receiverY.PhoneNumber, + }, + }, + WalletID: walletA.ID, + Asset: *asset1, + }, + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletYA.ID, + Receiver: Receiver{ + ID: receiverY.ID, + Email: receiverY.Email, + PhoneNumber: receiverY.PhoneNumber, + }, + }, + WalletID: walletA.ID, + Asset: *asset2, + }, + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletYB.ID, + Receiver: Receiver{ + ID: receiverY.ID, + Email: receiverY.Email, + PhoneNumber: receiverY.PhoneNumber, + }, + }, + WalletID: walletB.ID, + Asset: *asset1, + }, + { + ReceiverWallet: ReceiverWallet{ + ID: receiverWalletYB.ID, + Receiver: Receiver{ + ID: receiverY.ID, + Email: receiverY.Email, + PhoneNumber: receiverY.PhoneNumber, + }, + }, + WalletID: walletB.ID, + Asset: *asset2, + }, + } + + assert.ElementsMatch(t, wantLatestAssetsPerRW, gotLatestAssetsPerRW) +} diff --git a/internal/data/countries.go b/internal/data/countries.go new file mode 100644 index 000000000..d505eff3d --- /dev/null +++ b/internal/data/countries.go @@ -0,0 +1,69 @@ +package data + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type Country struct { + Code string `json:"code" db:"code"` + Name string `json:"name" db:"name"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + DeletedAt *time.Time `json:"-" db:"deleted_at"` +} + +type CountryModel struct { + dbConnectionPool db.DBConnectionPool +} + +func (m *CountryModel) Get(ctx context.Context, code string) (*Country, error) { + var country Country + query := ` + SELECT + c.code, + c.name, + c.created_at, + c.updated_at + FROM + countries c + WHERE + c.code = $1 + ` + + err := m.dbConnectionPool.GetContext(ctx, &country, query, code) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying country code %s: %w", code, err) + } + return &country, nil +} + +// GetAll returns all countries in the database +func (m *CountryModel) GetAll(ctx context.Context) ([]Country, error) { + countries := []Country{} + query := ` + SELECT + c.code, + c.name, + c.created_at, + c.updated_at + FROM + countries c + ORDER BY + c.name ASC + ` + + err := m.dbConnectionPool.SelectContext(ctx, &countries, query) + if err != nil { + return nil, fmt.Errorf("error querying countries: %w", err) + } + return countries, nil +} diff --git a/internal/data/countries_test.go b/internal/data/countries_test.go new file mode 100644 index 000000000..d688ae0a7 --- /dev/null +++ b/internal/data/countries_test.go @@ -0,0 +1,65 @@ +package data + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CountryModelGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + countryModel := &CountryModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when country is not found", func(t *testing.T) { + _, err := countryModel.Get(ctx, "not-found") + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns asset successfully", func(t *testing.T) { + expected := CreateCountryFixture(t, ctx, dbConnectionPool.SqlxDB(), "FRA", "France") + actual, err := countryModel.Get(ctx, "FRA") + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) +} + +func Test_CountryModelGetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + countryModel := &CountryModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns all countries successfully", func(t *testing.T) { + expected := ClearAndCreateCountryFixtures(t, ctx, dbConnectionPool.SqlxDB()) + actual, err := countryModel.GetAll(ctx) + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) + + t.Run("returns empty array when no countries", func(t *testing.T) { + DeleteAllCountryFixtures(t, ctx, dbConnectionPool.SqlxDB()) + actual, err := countryModel.GetAll(ctx) + require.NoError(t, err) + + assert.Equal(t, []Country{}, actual) + }) +} diff --git a/internal/data/dibursements_state_machine.go b/internal/data/dibursements_state_machine.go new file mode 100644 index 000000000..7d0785601 --- /dev/null +++ b/internal/data/dibursements_state_machine.go @@ -0,0 +1,75 @@ +package data + +import ( + "fmt" + "strings" +) + +type DisbursementStatus string + +const ( + DraftDisbursementStatus DisbursementStatus = "DRAFT" + ReadyDisbursementStatus DisbursementStatus = "READY" + StartedDisbursementStatus DisbursementStatus = "STARTED" + PausedDisbursementStatus DisbursementStatus = "PAUSED" + CompletedDisbursementStatus DisbursementStatus = "COMPLETED" +) + +// TransitionTo transitions the disbursement status to the target state +func (status DisbursementStatus) TransitionTo(targetState DisbursementStatus) error { + return DisbursementStateMachineWithInitialState(status).TransitionTo(targetState.State()) +} + +// DisbursementStatuses returns a list of all possible disbursement statuses +func DisbursementStatuses() []DisbursementStatus { + return []DisbursementStatus{DraftDisbursementStatus, ReadyDisbursementStatus, StartedDisbursementStatus, PausedDisbursementStatus, CompletedDisbursementStatus} +} + +// SourceStatuses returns a list of states that the payment status can transition from given the target state +func (status DisbursementStatus) SourceStatuses() []DisbursementStatus { + stateMachine := DisbursementStateMachineWithInitialState(DraftDisbursementStatus) + fromStates := []DisbursementStatus{} + for _, fromState := range DisbursementStatuses() { + if stateMachine.Transitions[fromState.State()][status.State()] { + fromStates = append(fromStates, fromState) + } + } + return fromStates +} + +// DisbursementStateMachineWithInitialState returns a state machine for disbursements initialized with the given state +func DisbursementStateMachineWithInitialState(initialState DisbursementStatus) *StateMachine { + transitions := []StateTransition{ + {From: DraftDisbursementStatus.State(), To: ReadyDisbursementStatus.State()}, // instructions uploaded successfully + {From: ReadyDisbursementStatus.State(), To: ReadyDisbursementStatus.State()}, // user re-uploads instructions + {From: ReadyDisbursementStatus.State(), To: StartedDisbursementStatus.State()}, // user starts disbursement + {From: StartedDisbursementStatus.State(), To: PausedDisbursementStatus.State()}, // user pauses disbursement + {From: PausedDisbursementStatus.State(), To: StartedDisbursementStatus.State()}, // user resumes disbursement + {From: StartedDisbursementStatus.State(), To: CompletedDisbursementStatus.State()}, // all payments went through + } + + return NewStateMachine(initialState.State(), transitions) +} + +// Validate validates the disbursement status +func (status DisbursementStatus) Validate() error { + switch DisbursementStatus(strings.ToUpper(string(status))) { + case DraftDisbursementStatus, ReadyDisbursementStatus, StartedDisbursementStatus, PausedDisbursementStatus, CompletedDisbursementStatus: + return nil + default: + return fmt.Errorf("invalid disbursement status: %s", status) + } +} + +// ToDisbursementStatus converts a string to a DisbursementStatus +func ToDisbursementStatus(s string) (DisbursementStatus, error) { + err := DisbursementStatus(s).Validate() + if err != nil { + return "", err + } + return DisbursementStatus(strings.ToUpper(s)), nil +} + +func (status DisbursementStatus) State() State { + return State(status) +} diff --git a/internal/data/disbursement_instructions.go b/internal/data/disbursement_instructions.go new file mode 100644 index 000000000..151d9c1e1 --- /dev/null +++ b/internal/data/disbursement_instructions.go @@ -0,0 +1,213 @@ +package data + +import ( + "context" + "errors" + "fmt" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type DisbursementInstruction struct { + Phone string `csv:"phone"` + ID string `csv:"id"` + Amount string `csv:"amount"` + VerificationValue string `csv:"verification"` +} + +type DisbursementInstructionModel struct { + dbConnectionPool db.DBConnectionPool + receiverVerificationModel *ReceiverVerificationModel + receiverWalletModel *ReceiverWalletModel + receiverModel *ReceiverModel + paymentModel *PaymentModel + disbursementModel *DisbursementModel +} + +const MaxInstructionsPerDisbursement = 10000 // TODO: update this number with load testing results [SDP-524] + +// NewDisbursementInstructionModel creates a new DisbursementInstructionModel. +func NewDisbursementInstructionModel(dbConnectionPool db.DBConnectionPool) *DisbursementInstructionModel { + return &DisbursementInstructionModel{ + dbConnectionPool: dbConnectionPool, + receiverVerificationModel: &ReceiverVerificationModel{}, + receiverWalletModel: &ReceiverWalletModel{dbConnectionPool: dbConnectionPool}, + receiverModel: &ReceiverModel{}, + paymentModel: &PaymentModel{dbConnectionPool: dbConnectionPool}, + disbursementModel: &DisbursementModel{dbConnectionPool: dbConnectionPool}, + } +} + +var ( + ErrMaxInstructionsExceeded = errors.New("maximum number of instructions exceeded") + ErrReceiverVerificationMismatch = errors.New("receiver verification mismatch") +) + +// ProcessAll Processes all disbursement instructions and persists the data to the database. +// +// |--- For each phone number in the instructions: +// | |--- Check if a receiver exists. +// | | |--- If a receiver does not exist, create one. +// | |--- For each receiver: +// | | |--- Check if the receiver verification exists. +// | | | |--- If the receiver verification does not exist, create one. +// | | | |--- If the receiver verification exists: +// | | | | |--- Check if the verification value matches. +// | | | | | |--- If the verification value does not match and the verification is confirmed, return an error. +// | | | | | |--- If the verification value does not match and the verification is not confirmed, update the verification value. +// | | | | | |--- If the verification value matches, continue. +// | | |--- Check if the receiver wallet exists. +// | | | |--- If the receiver wallet does not exist, create one. +// | | |--- Delete all payments tied to this disbursement. +// | | |--- Create all payments passed in the instructions. +func (di DisbursementInstructionModel) ProcessAll(ctx context.Context, userID string, instructions []*DisbursementInstruction, disbursement *Disbursement, update *DisbursementUpdate, maxNumberOfInstructions int) error { + if len(instructions) > maxNumberOfInstructions { + return ErrMaxInstructionsExceeded + } + + // We need all the following logic to be executed in one transaction. + return db.RunInTransaction(ctx, di.dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + // Step 1: Fetch all receivers by phone number and create missing ones + phoneNumbers := make([]string, 0, len(instructions)) + for _, instruction := range instructions { + phoneNumbers = append(phoneNumbers, instruction.Phone) + } + + existingReceivers, err := di.receiverModel.GetByPhoneNumbers(ctx, dbTx, phoneNumbers) + if err != nil { + return fmt.Errorf("error fetching receivers by phone number: %w", err) + } + + receiverMap := make(map[string]*Receiver) + for _, receiver := range existingReceivers { + receiverMap[receiver.PhoneNumber] = receiver + } + + instructionMap := make(map[string]*DisbursementInstruction) + for _, instruction := range instructions { + instructionMap[instruction.Phone] = instruction + } + + for _, instruction := range instructions { + _, exists := receiverMap[instruction.Phone] + if !exists { + receiverInsert := ReceiverInsert{ + PhoneNumber: instruction.Phone, + ExternalId: &instruction.ID, + } + receiver, insertErr := di.receiverModel.Insert(ctx, dbTx, receiverInsert) + if insertErr != nil { + return fmt.Errorf("error inserting receiver: %w", insertErr) + } + receiverMap[instruction.Phone] = receiver + } + } + + // Step 2: Fetch all receiver verifications and create missing ones. + receiverIDs := make([]string, 0, len(receiverMap)) + for _, receiver := range receiverMap { + receiverIDs = append(receiverIDs, receiver.ID) + } + verifications, err := di.receiverVerificationModel.GetByReceiverIdsAndVerificationField(ctx, dbTx, receiverIDs, disbursement.VerificationField) + if err != nil { + return fmt.Errorf("error fetching receiver verifications: %w", err) + } + + verificationMap := make(map[string]*ReceiverVerification) + for _, verification := range verifications { + verificationMap[verification.ReceiverID] = verification + } + + for _, receiver := range receiverMap { + verification, verificationExists := verificationMap[receiver.ID] + instruction := instructionMap[receiver.PhoneNumber] + if !verificationExists { + verificationInsert := ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationValue: instruction.VerificationValue, + VerificationField: disbursement.VerificationField, + } + hashedVerification, insertError := di.receiverVerificationModel.Insert(ctx, dbTx, verificationInsert) + if insertError != nil { + return fmt.Errorf("error inserting receiver verification: %w", insertError) + } + verificationMap[receiver.ID] = &ReceiverVerification{ + ReceiverID: verificationInsert.ReceiverID, + VerificationField: verificationInsert.VerificationField, + HashedValue: hashedVerification, + } + + } else { + if verified := CompareVerificationValue(verification.HashedValue, instruction.VerificationValue); !verified { + if verification.ConfirmedAt != nil { + return fmt.Errorf("%w: receiver verification for %s doesn't match", ErrReceiverVerificationMismatch, receiver.PhoneNumber) + } + err = di.receiverVerificationModel.UpdateVerificationValue(ctx, dbTx, verification.ReceiverID, verification.VerificationField, instruction.VerificationValue) + + if err != nil { + return fmt.Errorf("error updating receiver verification for disbursement id %s: %w", disbursement.ID, err) + } + } + } + } + + // Step 3: Fetch all receiver wallets and create missing ones + receiverWallets, err := di.receiverWalletModel.GetByReceiverIDsAndWalletID(ctx, dbTx, receiverIDs, disbursement.Wallet.ID) + if err != nil { + return fmt.Errorf("error fetching receiver wallets: %w", err) + } + receiverWalletsMap := make(map[string]string) + for _, receiverWallet := range receiverWallets { + receiverWalletsMap[receiverWallet.Receiver.ID] = receiverWallet.ID + } + + for _, receiverId := range receiverIDs { + _, exists := receiverWalletsMap[receiverId] + if !exists { + receiverWalletInsert := ReceiverWalletInsert{ + ReceiverID: receiverId, + WalletID: disbursement.Wallet.ID, + } + walletID, insertErr := di.receiverWalletModel.Insert(ctx, dbTx, receiverWalletInsert) + if insertErr != nil { + return fmt.Errorf("error inserting receiver wallet for receiver id %s: %w", receiverId, insertErr) + } + receiverWalletsMap[receiverId] = walletID + } + } + + // Step 4: Delete all payments tied to this disbursement for each receiver in one call + if err = di.paymentModel.DeleteAllForDisbursement(ctx, dbTx, disbursement.ID); err != nil { + return fmt.Errorf("error deleting payments: %w", err) + } + + // Step 5: Create payments for all receivers + payments := make([]PaymentInsert, 0, len(instructions)) + for _, instruction := range instructions { + receiver := receiverMap[instruction.Phone] + payment := PaymentInsert{ + ReceiverID: receiver.ID, + DisbursementID: disbursement.ID, + Amount: instruction.Amount, + AssetID: disbursement.Asset.ID, + ReceiverWalletID: receiverWalletsMap[receiver.ID], + } + payments = append(payments, payment) + } + if err = di.paymentModel.InsertAll(ctx, dbTx, payments); err != nil { + return fmt.Errorf("error inserting payments: %w", err) + } + + // Step 6: Persist Payment file to Disbursement + if err = di.disbursementModel.Update(ctx, update); err != nil { + return fmt.Errorf("error persisting payment file: %w", err) + } + + // Step 7: Update Disbursement Status + if err = di.disbursementModel.UpdateStatus(ctx, dbTx, userID, disbursement.ID, ReadyDisbursementStatus); err != nil { + return fmt.Errorf("error updating status: %w", err) + } + + return nil + }) +} diff --git a/internal/data/disbursement_instructions_test.go b/internal/data/disbursement_instructions_test.go new file mode 100644 index 000000000..c6530ff22 --- /dev/null +++ b/internal/data/disbursement_instructions_test.go @@ -0,0 +1,210 @@ +package data + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_DisbursementInstructionModel_ProcessAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, &DisbursementModel{dbConnectionPool: dbConnectionPool}, Disbursement{ + Name: "disbursement1", + Asset: asset, + Country: country, + Wallet: wallet, + }) + + di := NewDisbursementInstructionModel(dbConnectionPool) + + instruction1 := DisbursementInstruction{ + Phone: "+380-12-345-671", + Amount: "100.01", + ID: "123456781", + VerificationValue: "1990-01-01", + } + + instruction2 := DisbursementInstruction{ + Phone: "+380-12-345-672", + Amount: "100.02", + ID: "123456782", + VerificationValue: "1990-01-02", + } + + instruction3 := DisbursementInstruction{ + Phone: "+380-12-345-673", + Amount: "100.03", + ID: "123456783", + VerificationValue: "1990-01-03", + } + instructions := []*DisbursementInstruction{&instruction1, &instruction2, &instruction3} + expectedPhoneNumbers := []string{instruction1.Phone, instruction2.Phone, instruction3.Phone} + expectedExternalIDs := []string{instruction1.ID, instruction2.ID, instruction3.ID} + expectedPayments := []string{instruction1.Amount, instruction2.Amount, instruction3.Amount} + + disbursementUpdate := &DisbursementUpdate{ + ID: disbursement.ID, + FileName: "instructions.csv", + FileContent: CreateInstructionsFixture(t, instructions), + } + + t.Run("success", func(t *testing.T) { + err := di.ProcessAll(ctx, "user-id", instructions, disbursement, disbursementUpdate, MaxInstructionsPerDisbursement) + require.NoError(t, err) + + // Verify Receivers + receivers, err := di.receiverModel.GetByPhoneNumbers(ctx, dbConnectionPool, []string{instruction1.Phone, instruction2.Phone, instruction3.Phone}) + require.NoError(t, err) + assertEqualReceivers(t, expectedPhoneNumbers, expectedExternalIDs, receivers) + + // Verify ReceiverVerifications + receiverVerifications, err := di.receiverVerificationModel.GetByReceiverIdsAndVerificationField(ctx, dbConnectionPool, []string{receivers[0].ID, receivers[1].ID, receivers[2].ID}, VerificationFieldDateOfBirth) + require.NoError(t, err) + assertEqualVerifications(t, instructions, receiverVerifications, receivers) + + // Verify ReceiverWallets + receiverWallets, err := di.receiverWalletModel.GetByReceiverIDsAndWalletID(ctx, dbConnectionPool, []string{receivers[0].ID, receivers[1].ID, receivers[2].ID}, wallet.ID) + require.NoError(t, err) + assert.Len(t, receiverWallets, len(receivers)) + for _, receiverWallet := range receiverWallets { + assert.Equal(t, wallet.ID, receiverWallet.Wallet.ID) + assert.Equal(t, DraftReceiversWalletStatus, receiverWallet.Status) + } + + // Verify Payments + actualPayments := GetPaymentsByDisbursementID(t, ctx, dbConnectionPool, disbursement.ID) + assert.Equal(t, expectedPayments, actualPayments) + + // Verify Disbursement + actualDisbursement, err := di.disbursementModel.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + require.Equal(t, ReadyDisbursementStatus, actualDisbursement.Status) + require.Equal(t, disbursementUpdate.FileContent, actualDisbursement.FileContent) + require.Equal(t, disbursementUpdate.FileName, actualDisbursement.FileName) + }) + + t.Run("success - Not confirmed Verification Value updated", func(t *testing.T) { + // process instructions for the first time + err := di.ProcessAll(ctx, "user-id", instructions, disbursement, disbursementUpdate, MaxInstructionsPerDisbursement) + require.NoError(t, err) + + instruction1.VerificationValue = "1990-01-04" + err = di.ProcessAll(ctx, "user-id", instructions, disbursement, disbursementUpdate, MaxInstructionsPerDisbursement) + require.NoError(t, err) + + // Verify Receivers + receivers, err := di.receiverModel.GetByPhoneNumbers(ctx, dbConnectionPool, []string{instruction1.Phone, instruction2.Phone, instruction3.Phone}) + require.NoError(t, err) + assertEqualReceivers(t, expectedPhoneNumbers, expectedExternalIDs, receivers) + + // Verify ReceiverVerifications + receiverVerifications, err := di.receiverVerificationModel.GetByReceiverIdsAndVerificationField(ctx, dbConnectionPool, []string{receivers[0].ID, receivers[1].ID, receivers[2].ID}, VerificationFieldDateOfBirth) + require.NoError(t, err) + assertEqualVerifications(t, instructions, receiverVerifications, receivers) + + // Verify Disbursement + actualDisbursement, err := di.disbursementModel.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + require.Equal(t, ReadyDisbursementStatus, actualDisbursement.Status) + require.Equal(t, disbursementUpdate.FileContent, actualDisbursement.FileContent) + require.Equal(t, disbursementUpdate.FileName, actualDisbursement.FileName) + }) + + t.Run("failure - Too many instructions", func(t *testing.T) { + err := di.ProcessAll(ctx, "user-id", instructions, disbursement, disbursementUpdate, 2) + require.EqualError(t, err, "maximum number of instructions exceeded") + }) + + t.Run("failure - Confirmed Verification Value not matching", func(t *testing.T) { + // process instructions for the first time + err := di.ProcessAll(ctx, "user-id", instructions, disbursement, disbursementUpdate, MaxInstructionsPerDisbursement) + require.NoError(t, err) + + receivers, err := di.receiverModel.GetByPhoneNumbers(ctx, dbConnectionPool, []string{instruction1.Phone, instruction2.Phone, instruction3.Phone}) + require.NoError(t, err) + receiversMap := make(map[string]*Receiver) + for _, receiver := range receivers { + receiversMap[receiver.PhoneNumber] = receiver + } + + // confirm a verification + ConfirmVerificationForRecipient(t, ctx, dbConnectionPool, receiversMap[instruction1.Phone].ID) + + // process instructions with mismatched verification values + instruction1.VerificationValue = "1990-01-07" + err = di.ProcessAll(ctx, "user-id", instructions, disbursement, disbursementUpdate, MaxInstructionsPerDisbursement) + require.Error(t, err) + assert.EqualError(t, err, "running atomic function in RunInTransactionWithResult: receiver verification mismatch: receiver verification for +380-12-345-671 doesn't match") + }) +} + +func assertEqualReceivers(t *testing.T, expectedPhones, expectedExternalIDs []string, actualReceivers []*Receiver) { + assert.Len(t, actualReceivers, len(expectedPhones)) + + for _, actual := range actualReceivers { + assert.Contains(t, expectedPhones, actual.PhoneNumber) + assert.Contains(t, expectedExternalIDs, actual.ExternalID) + } +} + +func assertEqualVerifications(t *testing.T, expectedInstructions []*DisbursementInstruction, actualVerifications []*ReceiverVerification, receivers []*Receiver) { + assert.Len(t, actualVerifications, len(expectedInstructions)) + + instructionsMap := make(map[string]*DisbursementInstruction) + for _, instruction := range expectedInstructions { + instructionsMap[instruction.Phone] = instruction + } + phonesByReceiverId := make(map[string]string) + for _, receiver := range receivers { + phonesByReceiverId[receiver.ID] = receiver.PhoneNumber + } + + for _, actual := range actualVerifications { + instruction := instructionsMap[phonesByReceiverId[actual.ReceiverID]] + verified := CompareVerificationValue(actual.HashedValue, instruction.VerificationValue) + assert.True(t, verified) + } +} + +func ConfirmVerificationForRecipient(t *testing.T, ctx context.Context, dbConnectionPool db.DBConnectionPool, receiverID string) { + query := ` + UPDATE + receiver_verifications + SET + confirmed_at = now() + WHERE + receiver_id = $1 + ` + _, err := dbConnectionPool.ExecContext(ctx, query, receiverID) + require.NoError(t, err) +} + +func GetPaymentsByDisbursementID(t *testing.T, ctx context.Context, dbConnectionPool db.DBConnectionPool, disbursementID string) []string { + query := ` + SELECT + ROUND(p.amount, 2) + FROM + payments p + WHERE p.disbursement_id = $1 + ` + var payments []string + err := dbConnectionPool.SelectContext(ctx, &payments, query, disbursementID) + require.NoError(t, err) + return payments +} diff --git a/internal/data/disbursement_receivers.go b/internal/data/disbursement_receivers.go new file mode 100644 index 000000000..4d450e875 --- /dev/null +++ b/internal/data/disbursement_receivers.go @@ -0,0 +1,97 @@ +package data + +import ( + "context" + "fmt" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type DisbursementReceiver struct { + ID string `json:"id" db:"id"` + Email string `json:"email,omitempty" db:"email"` + PhoneNumber string `json:"phone_number" db:"phone_number"` + ExternalID string `json:"external_id" db:"external_id"` + ReceiverWallet *ReceiverWallet `json:"receiver_wallet" db:"receiver_wallet"` + Payment *Payment `json:"payment" db:"payment"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type DisbursementReceiverModel struct { + dbConnectionPool db.DBConnectionPool +} + +func (m DisbursementReceiverModel) Count(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) (int, error) { + var count int + query := ` + SELECT + count(*) + FROM + receivers r + JOIN payments p ON r.id = p.receiver_id + WHERE p.disbursement_id = $1 + ` + + err := sqlExec.GetContext(ctx, &count, query, disbursementID) + if err != nil { + return 0, fmt.Errorf("error counting disbursement receivers for disbursement ID %s: %w", disbursementID, err) + } + return count, nil +} + +func (m DisbursementReceiverModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams, disbursementID string) ([]*DisbursementReceiver, error) { + var receivers []*DisbursementReceiver + baseQuery := ` + SELECT + r.id, + r.phone_number, + r.external_id, + COALESCE(r.email, '') as email, + r.created_at, + r.updated_at, + rw.id as "receiver_wallet.id", + rw.receiver_id as "receiver_wallet.receiver.id", + COALESCE(rw.stellar_address, '') as "receiver_wallet.stellar_address", + COALESCE(rw.stellar_memo, '') as "receiver_wallet.stellar_memo", + COALESCE(rw.stellar_memo_type, '') as "receiver_wallet.stellar_memo_type", + rw.status as "receiver_wallet.status", + rw.created_at as "receiver_wallet.created_at", + rw.updated_at as "receiver_wallet.updated_at", + w.id as "receiver_wallet.wallet.id", + w.name as "receiver_wallet.wallet.name", + p.id as "payment.id", + p.amount as "payment.amount", + p.status as "payment.status", + COALESCE(p.stellar_transaction_id, '') as "payment.stellar_transaction_id", + COALESCE(p.stellar_operation_id, '') as "payment.stellar_operation_id", + p.created_at as "payment.created_at", + p.updated_at as "payment.updated_at", + a.id as "payment.asset.id", + a.code as "payment.asset.code", + a.issuer as "payment.asset.issuer" + FROM + receivers r + JOIN payments p ON r.id = p.receiver_id + JOIN receiver_wallets rw ON rw.id = p.receiver_wallet_id + JOIN wallets w ON rw.wallet_id = w.id + JOIN assets a ON p.asset_id = a.id + ` + + query, params := m.newDisbursementReceiversQuery(baseQuery, queryParams, disbursementID) + err := sqlExec.SelectContext(ctx, &receivers, query, params...) + if err != nil { + return nil, fmt.Errorf("error getting receivers: %w", err) + } + return receivers, nil +} + +func (m DisbursementReceiverModel) newDisbursementReceiversQuery(baseQuery string, queryParams *QueryParams, disbursementID string) (string, []interface{}) { + qb := NewQueryBuilder(baseQuery) + qb.AddCondition("p.disbursement_id = ?", disbursementID) + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "r") + qb.AddPagination(queryParams.Page, queryParams.PageLimit) + query, params := qb.Build() + return m.dbConnectionPool.Rebind(query), params +} diff --git a/internal/data/disbursement_receivers_test.go b/internal/data/disbursement_receivers_test.go new file mode 100644 index 000000000..dbf0d1df5 --- /dev/null +++ b/internal/data/disbursement_receivers_test.go @@ -0,0 +1,134 @@ +package data + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/require" +) + +func Test_DisbursementReceiverModel_Count(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := &DisbursementModel{dbConnectionPool: dbConnectionPool} + disbursementReceiverModel := &DisbursementReceiverModel{dbConnectionPool: dbConnectionPool} + paymentModel := &PaymentModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, disbursementModel, &Disbursement{ + Country: country, + Wallet: wallet, + Status: ReadyDisbursementStatus, + Asset: asset, + }) + + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + rwDraft1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, DraftReceiversWalletStatus) + rwDraft2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, DraftReceiversWalletStatus) + + require.NotNil(t, rwDraft1) + require.NotNil(t, rwDraft2) + + t.Run("no receivers for disbursement 1", func(t *testing.T) { + count, err := disbursementReceiverModel.Count(ctx, dbConnectionPool, disbursement1.ID) + require.NoError(t, err) + require.Equal(t, 0, count) + }) + + t.Run("count receivers for disbursement 1", func(t *testing.T) { + CreatePaymentFixture(t, ctx, dbConnectionPool, paymentModel, &Payment{ + ReceiverWallet: rwDraft1, + Disbursement: disbursement1, + Asset: *asset, + Amount: "100", + Status: DraftPaymentStatus, + }) + CreatePaymentFixture(t, ctx, dbConnectionPool, paymentModel, &Payment{ + ReceiverWallet: rwDraft2, + Disbursement: disbursement1, + Asset: *asset, + Amount: "200", + Status: DraftPaymentStatus, + }) + + count, err := disbursementReceiverModel.Count(ctx, dbConnectionPool, disbursement1.ID) + require.NoError(t, err) + require.Equal(t, 2, count) + }) +} + +func Test_DisbursementReceiverModel_GetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := &DisbursementModel{dbConnectionPool: dbConnectionPool} + disbursementReceiverModel := &DisbursementReceiverModel{dbConnectionPool: dbConnectionPool} + paymentModel := &PaymentModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, disbursementModel, &Disbursement{ + Country: country, + Wallet: wallet, + Status: ReadyDisbursementStatus, + Asset: asset, + }) + + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + rwDraft1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, DraftReceiversWalletStatus) + rwDraft2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, DraftReceiversWalletStatus) + + require.NotNil(t, rwDraft1) + require.NotNil(t, rwDraft2) + + t.Run("no receivers for disbursement 1", func(t *testing.T) { + receivers, err := disbursementReceiverModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, disbursement1.ID) + require.NoError(t, err) + require.Equal(t, 0, len(receivers)) + }) + + t.Run("get all receivers for disbursement 1", func(t *testing.T) { + CreatePaymentFixture(t, ctx, dbConnectionPool, paymentModel, &Payment{ + ReceiverWallet: rwDraft1, + Disbursement: disbursement1, + Asset: *asset, + Amount: "100", + Status: DraftPaymentStatus, + }) + CreatePaymentFixture(t, ctx, dbConnectionPool, paymentModel, &Payment{ + ReceiverWallet: rwDraft2, + Disbursement: disbursement1, + Asset: *asset, + Amount: "200", + Status: DraftPaymentStatus, + }) + + receivers, err := disbursementReceiverModel.GetAll(ctx, dbConnectionPool, &QueryParams{}, disbursement1.ID) + require.NoError(t, err) + require.Equal(t, 2, len(receivers)) + }) +} diff --git a/internal/data/disbursements.go b/internal/data/disbursements.go new file mode 100644 index 000000000..70fd6e36f --- /dev/null +++ b/internal/data/disbursements.go @@ -0,0 +1,528 @@ +package data + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/lib/pq" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type Disbursement struct { + ID string `json:"id" db:"id"` + Name string `json:"name" db:"name"` + Country *Country `json:"country,omitempty" db:"country"` + Wallet *Wallet `json:"wallet,omitempty" db:"wallet"` + Asset *Asset `json:"asset,omitempty" db:"asset"` + Status DisbursementStatus `json:"status" db:"status"` + VerificationField VerificationField `json:"verification_field,omitempty" db:"verification_field"` + StatusHistory DisbursementStatusHistory `json:"status_history,omitempty" db:"status_history"` + FileName string `json:"file_name,omitempty" db:"file_name"` + FileContent []byte `json:"-" db:"file_content"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + *DisbursementStats +} + +type DisbursementStatusHistory []DisbursementStatusHistoryEntry + +type DisbursementStats struct { + TotalPayments int `json:"total_payments" db:"total_payments"` + SuccessfulPayments int `json:"total_payments_sent" db:"total_payments_sent"` + FailedPayments int `json:"total_payments_failed" db:"total_payments_failed"` + RemainingPayments int `json:"total_payments_remaining" db:"total_payments_remaining"` + AmountDisbursed string `json:"amount_disbursed" db:"amount_disbursed"` + TotalAmount string `json:"total_amount" db:"total_amount"` + AverageAmount string `json:"average_amount" db:"average_amount"` +} + +type DisbursementUpdate struct { + ID string + FileName string + FileContent []byte +} + +type VerificationField string + +const ( + VerificationFieldDateOfBirth VerificationField = "DATE_OF_BIRTH" + VerificationFieldPin VerificationField = "PIN" + VerificationFieldNationalID VerificationField = "NATIONAL_ID_NUMBER" +) + +type DisbursementStatusHistoryEntry struct { + UserID string `json:"user_id"` + Status DisbursementStatus `json:"status"` + Timestamp time.Time `json:"timestamp"` +} +type DisbursementModel struct { + dbConnectionPool db.DBConnectionPool +} + +var ( + DefaultDisbursementSortField = SortFieldCreatedAt + DefaultDisbursementSortOrder = SortOrderDESC + AllowedDisbursementFilters = []FilterKey{FilterKeyStatus, FilterKeyCreatedAtAfter, FilterKeyCreatedAtBefore} + AllowedDisbursementSorts = []SortField{SortFieldName, SortFieldCreatedAt} +) + +func (d *DisbursementModel) Insert(ctx context.Context, disbursement *Disbursement) (string, error) { + const q = ` + INSERT INTO + disbursements (name, status, status_history, wallet_id, asset_id, country_code) + VALUES + ($1, $2, $3, $4, $5, $6) + RETURNING id + ` + var newId string + err := d.dbConnectionPool.GetContext(ctx, &newId, q, + disbursement.Name, + disbursement.Status, + disbursement.StatusHistory, + disbursement.Wallet.ID, + disbursement.Asset.ID, + disbursement.Country.Code, + ) + if err != nil { + // check if the error is a duplicate key error + if strings.Contains(err.Error(), "duplicate key") { + return "", ErrRecordAlreadyExists + } + return "", fmt.Errorf("unable to create disbursement %s: %w", disbursement.Name, err) + } + + return newId, nil +} + +func (d *DisbursementModel) GetWithStatistics(ctx context.Context, id string) (*Disbursement, error) { + disbursement, err := d.Get(ctx, d.dbConnectionPool, id) + if err != nil { + return nil, err + } + + err = d.populateStatistics(ctx, []*Disbursement{disbursement}) + if err != nil { + return nil, fmt.Errorf("error populating statistics for disbursement ID %s: %w", id, err) + } + + return disbursement, nil +} + +func (d *DisbursementModel) Get(ctx context.Context, sqlExec db.SQLExecuter, id string) (*Disbursement, error) { + var disbursement Disbursement + + query := ` + SELECT + d.id, + d.name, + d.status, + d.status_history, + d.verification_field, + COALESCE(d.file_name, '') as file_name, + d.file_content, + d.created_at, + d.updated_at, + w.id as "wallet.id", + w.name as "wallet.name", + w.homepage as "wallet.homepage", + w.sep_10_client_domain as "wallet.sep_10_client_domain", + w.deep_link_schema as "wallet.deep_link_schema", + w.created_at as "wallet.created_at", + w.updated_at as "wallet.updated_at", + a.id as "asset.id", + a.code as "asset.code", + a.issuer as "asset.issuer", + a.created_at as "asset.created_at", + a.updated_at as "asset.updated_at", + c.code as "country.code", + c.name as "country.name", + c.created_at as "country.created_at", + c.updated_at as "country.updated_at" + FROM + disbursements d + JOIN wallets w on d.wallet_id = w.id + JOIN assets a on d.asset_id = a.id + JOIN countries c on d.country_code = c.code + WHERE + d.id = $1 + ` + err := sqlExec.GetContext(ctx, &disbursement, query, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying disbursement ID %s: %w", id, err) + } + + return &disbursement, nil +} + +func (d *DisbursementModel) GetByName(ctx context.Context, sqlExec db.SQLExecuter, name string) (*Disbursement, error) { + var disbursement Disbursement + + query := ` + SELECT + d.id, + d.name, + d.status, + d.status_history, + d.verification_field, + COALESCE(d.file_name, '') as file_name, + d.file_content, + d.created_at, + d.updated_at, + w.id as "wallet.id", + w.name as "wallet.name", + w.homepage as "wallet.homepage", + w.sep_10_client_domain as "wallet.sep_10_client_domain", + w.deep_link_schema as "wallet.deep_link_schema", + w.created_at as "wallet.created_at", + w.updated_at as "wallet.updated_at", + a.id as "asset.id", + a.code as "asset.code", + a.issuer as "asset.issuer", + a.created_at as "asset.created_at", + a.updated_at as "asset.updated_at", + c.code as "country.code", + c.name as "country.name", + c.created_at as "country.created_at", + c.updated_at as "country.updated_at" + FROM + disbursements d + JOIN wallets w on d.wallet_id = w.id + JOIN assets a on d.asset_id = a.id + JOIN countries c on d.country_code = c.code + WHERE + d.name = $1 + ` + err := sqlExec.GetContext(ctx, &disbursement, query, name) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying disbursement with name %s: %w", name, err) + } + + return &disbursement, nil +} + +// populateStatistics populates the payment statistics for the given disbursements +func (d *DisbursementModel) populateStatistics(ctx context.Context, disbursements []*Disbursement) error { + if len(disbursements) == 0 { + return nil + } + + disbursementIDs := make([]string, len(disbursements)) + for i, disbursement := range disbursements { + disbursementIDs[i] = disbursement.ID + } + + query := ` + SELECT + disbursement_id, + count(*) as total_payments, + sum(case when status = 'SUCCESS' then 1 else 0 end) as total_payments_sent, + sum(case when status = 'FAILED' then 1 else 0 end) as total_payments_failed, + sum(case when status IN ('DRAFT', 'READY', 'PENDING', 'PAUSED') then 1 else 0 end) as total_payments_remaining, + ROUND(SUM(CASE WHEN status = 'SUCCESS' THEN amount ELSE 0 END), 2) as amount_disbursed, + ROUND(SUM(amount), 2) as total_amount, + ROUND(avg(amount), 2) as average_amount + FROM + payments + WHERE + disbursement_id = ANY ($1) + GROUP BY + disbursement_id; + ` + + rows, err := d.dbConnectionPool.QueryxContext(ctx, query, pq.Array(disbursementIDs)) + if err != nil { + return fmt.Errorf("error querying disbursement statistics: %w", err) + } + defer db.CloseRows(ctx, rows) + + statistics := make(map[string]*DisbursementStats) + for rows.Next() { + var disbursementID string + var stats DisbursementStats + err := rows.Scan( + &disbursementID, + &stats.TotalPayments, + &stats.SuccessfulPayments, + &stats.FailedPayments, + &stats.RemainingPayments, + &stats.AmountDisbursed, + &stats.TotalAmount, + &stats.AverageAmount, + ) + if err != nil { + return fmt.Errorf("error scanning disbursement statistics: %w", err) + } + statistics[disbursementID] = &stats + } + + if len(statistics) == 0 { + return nil + } + + // populate the statistics + for _, disbursement := range disbursements { + disbursement.DisbursementStats = statistics[disbursement.ID] + } + return nil +} + +// Count returns the number of disbursements matching the given query parameters. +func (d *DisbursementModel) Count(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams) (int, error) { + var count int + baseQuery := ` + SELECT + count(*) + FROM + disbursements d + JOIN wallets w on d.wallet_id = w.id + JOIN assets a on d.asset_id = a.id + JOIN countries c on d.country_code = c.code + ` + + query, params := d.newDisbursementQuery(baseQuery, queryParams, false) + + err := sqlExec.GetContext(ctx, &count, query, params...) + if err != nil { + return 0, fmt.Errorf("error counting disbursements: %w", err) + } + return count, nil +} + +// GetAll returns all disbursements matching the given query parameters. +func (d *DisbursementModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams) ([]*Disbursement, error) { + disbursements := []*Disbursement{} + + baseQuery := ` + SELECT + d.id, + d.name, + d.status, + d.status_history, + d.verification_field, + d.created_at, + d.updated_at, + COALESCE(d.file_name, '') as file_name, + w.id as "wallet.id", + w.name as "wallet.name", + w.homepage as "wallet.homepage", + w.sep_10_client_domain as "wallet.sep_10_client_domain", + w.deep_link_schema as "wallet.deep_link_schema", + w.created_at as "wallet.created_at", + w.updated_at as "wallet.updated_at", + a.id as "asset.id", + a.code as "asset.code", + a.issuer as "asset.issuer", + a.created_at as "asset.created_at", + a.updated_at as "asset.updated_at", + c.code as "country.code", + c.name as "country.name", + c.created_at as "country.created_at", + c.updated_at as "country.updated_at" + FROM + disbursements d + JOIN wallets w on d.wallet_id = w.id + JOIN assets a on d.asset_id = a.id + JOIN countries c on d.country_code = c.code + ` + + query, params := d.newDisbursementQuery(baseQuery, queryParams, true) + err := sqlExec.SelectContext(ctx, &disbursements, query, params...) + if err != nil { + return nil, fmt.Errorf("error querying disbursements: %w", err) + } + + // populate the statistics + if err = d.populateStatistics(ctx, disbursements); err != nil { + return nil, fmt.Errorf("error populating disbursement statistics: %w", err) + } + return disbursements, nil +} + +// UpdateStatus updates the status of the given disbursement. +func (d *DisbursementModel) UpdateStatus(ctx context.Context, sqlExec db.SQLExecuter, userID string, disbursementID string, targetStatus DisbursementStatus) error { + sourceStatuses := targetStatus.SourceStatuses() + + query := ` + UPDATE + disbursements + SET + status = $1, + status_history = array_append(status_history, create_disbursement_status_history(NOW(), $1, $2)) + WHERE + id = $3 AND status = ANY($4) + ` + result, err := sqlExec.ExecContext(ctx, query, targetStatus, userID, disbursementID, pq.Array(sourceStatuses)) + if err != nil { + return fmt.Errorf("error updating disbursement status: %w", err) + } + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + if numRowsAffected == 0 { + return fmt.Errorf("disbursement %s status was not updated from %s to %s", disbursementID, sourceStatuses, targetStatus) + } else if numRowsAffected == 1 { + log.Ctx(ctx).Infof("Set disbursement %s status from %s to %s", disbursementID, sourceStatuses, targetStatus) + } else { + return fmt.Errorf("unexpected number of rows affected: %d when updating disbursement %s status from %s to %s", + numRowsAffected, + disbursementID, + sourceStatuses, + targetStatus) + } + + return nil +} + +// newDisbursementQuery generates the full query and parameters for a disbursement search query +func (d *DisbursementModel) newDisbursementQuery(baseQuery string, queryParams *QueryParams, paginated bool) (string, []interface{}) { + qb := NewQueryBuilder(baseQuery) + if queryParams.Query != "" { + qb.AddCondition("d.name ILIKE ?", "%"+queryParams.Query+"%") + } + + if statusSlice, ok := queryParams.Filters[FilterKeyStatus].([]DisbursementStatus); ok && len(statusSlice) > 0 { + qb.AddCondition("d.status = ANY(?)", pq.Array(statusSlice)) + } + if queryParams.Filters[FilterKeyCreatedAtAfter] != nil { + qb.AddCondition("d.created_at >= ?", queryParams.Filters[FilterKeyCreatedAtAfter]) + } + if queryParams.Filters[FilterKeyCreatedAtBefore] != nil { + qb.AddCondition("d.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore]) + } + if paginated { + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "d") + qb.AddPagination(queryParams.Page, queryParams.PageLimit) + } + query, params := qb.Build() + return d.dbConnectionPool.Rebind(query), params +} + +func (du *DisbursementUpdate) Validate() error { + if du.FileName == "" { + return errors.New("file name is required") + } + if len(du.FileContent) == 0 { + return errors.New("file content is required") + } + if du.ID == "" { + return errors.New("disbursement ID is required") + } + return nil +} + +func (d *DisbursementModel) Update(ctx context.Context, du *DisbursementUpdate) error { + if err := du.Validate(); err != nil { + return fmt.Errorf("error validating disbursement update: %w", err) + } + + query := ` + UPDATE + disbursements + SET + file_name = $1, + file_content = $2 + WHERE + id = $3 + ` + result, err := d.dbConnectionPool.ExecContext(ctx, query, du.FileName, du.FileContent, du.ID) + if err != nil { + return fmt.Errorf("error updating disbursement: %w", err) + } + + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + if numRowsAffected != 1 { + return fmt.Errorf("disbursement %s was not updated", du.ID) + } + + return nil +} + +// Value implements the driver.Valuer interface. +func (dsh DisbursementStatusHistory) Value() (driver.Value, error) { + var statusHistoryJSON []string + for _, sh := range dsh { + shJSONBytes, err := json.Marshal(sh) + if err != nil { + return nil, fmt.Errorf("error converting status history to json for disbursement: %w", err) + } + statusHistoryJSON = append(statusHistoryJSON, string(shJSONBytes)) + } + + return pq.Array(statusHistoryJSON).Value() +} + +// Scan implements the sql.Scanner interface. +func (dsh *DisbursementStatusHistory) Scan(src interface{}) error { + var statusHistoryJSON []string + if err := pq.Array(&statusHistoryJSON).Scan(src); err != nil { + return fmt.Errorf("error scanning status history value: %w", err) + } + + for _, sh := range statusHistoryJSON { + var shEntry DisbursementStatusHistoryEntry + err := json.Unmarshal([]byte(sh), &shEntry) + if err != nil { + return fmt.Errorf("error unmarshaling status_history column: %w", err) + } + *dsh = append(*dsh, shEntry) + } + + return nil +} + +// CompleteDisbursements sets disbursements statuses to complete after all payments are processed and successfully sent. +func (d *DisbursementModel) CompleteDisbursements(ctx context.Context, sqlExec db.SQLExecuter, disbursementIDs []string) error { + query := ` + WITH incompleted_disbursements AS ( + SELECT + p.disbursement_id, + COUNT(p.*) + FROM + payments p + INNER JOIN disbursements d ON d.id = p.disbursement_id + WHERE + p.status != $4 + AND d.status = $3 + AND d.id = ANY($2) + GROUP BY + p.status, + p.disbursement_id + HAVING + COUNT(p.*) > 0 + ) + UPDATE + disbursements + SET + status = $1, + status_history = array_append(status_history, create_disbursement_status_history(NOW(), $1, '')) + WHERE + id = ANY($2) + AND status = $3 + AND id NOT IN (SELECT disbursement_id FROM incompleted_disbursements) + ` + + _, err := sqlExec.ExecContext(ctx, query, CompletedDisbursementStatus, pq.Array(disbursementIDs), StartedDisbursementStatus, SuccessPaymentStatus) + if err != nil { + return fmt.Errorf("error completing disbursement: %w", err) + } + + return nil +} diff --git a/internal/data/disbursements_state_machine_test.go b/internal/data/disbursements_state_machine_test.go new file mode 100644 index 000000000..eba63a02e --- /dev/null +++ b/internal/data/disbursements_state_machine_test.go @@ -0,0 +1,180 @@ +package data + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_DisbursementStatus_ToDisbursementStatus(t *testing.T) { + tests := []struct { + name string + actual string + want DisbursementStatus + err error + }{ + { + name: "valid entry", + actual: "STARTED", + want: StartedDisbursementStatus, + err: nil, + }, + { + name: "valid lower case", + actual: "draft", + want: DraftDisbursementStatus, + err: nil, + }, + { + name: "valid weird case", + actual: "ReAdY", + want: ReadyDisbursementStatus, + err: nil, + }, + { + name: "invalid entry", + actual: "NOT_VALID", + want: StartedDisbursementStatus, + err: fmt.Errorf("invalid disbursement status: NOT_VALID"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ToDisbursementStatus(tt.actual) + + if tt.err != nil { + require.EqualError(t, err, tt.err.Error()) + return + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +func Test_DisbursementStatus_TransitionTo(t *testing.T) { + tests := []struct { + name string + actual DisbursementStatus + target DisbursementStatus + err error + }{ + { + name: "instructions uploaded successfully transition", + actual: DraftDisbursementStatus, + target: ReadyDisbursementStatus, + err: nil, + }, + { + name: "user re-uploads instructions transition", + actual: ReadyDisbursementStatus, + target: ReadyDisbursementStatus, + err: nil, + }, + { + name: "instructions uploaded successfully transition", + actual: DraftDisbursementStatus, + target: ReadyDisbursementStatus, + err: nil, + }, + { + name: "user starts disbursement transition", + actual: ReadyDisbursementStatus, + target: StartedDisbursementStatus, + err: nil, + }, + { + name: "user pauses disbursement transition", + actual: StartedDisbursementStatus, + target: PausedDisbursementStatus, + err: nil, + }, + { + name: "user resumes disbursement transition", + actual: PausedDisbursementStatus, + target: StartedDisbursementStatus, + err: nil, + }, + { + name: "all payments went through transition", + actual: StartedDisbursementStatus, + target: CompletedDisbursementStatus, + err: nil, + }, + { + name: "invalid transition 1", + actual: DraftDisbursementStatus, + target: StartedDisbursementStatus, + err: fmt.Errorf("cannot transition from DRAFT to STARTED"), + }, + { + name: "invalid transition 2", + actual: StartedDisbursementStatus, + target: DraftDisbursementStatus, + err: fmt.Errorf("cannot transition from STARTED to DRAFT"), + }, + { + name: "invalid transition 3", + actual: DraftDisbursementStatus, + target: PausedDisbursementStatus, + err: fmt.Errorf("cannot transition from DRAFT to PAUSED"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.actual.TransitionTo(tt.target) + if tt.err != nil { + require.EqualError(t, err, tt.err.Error()) + return + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_DisbursementStatus_SourceStatuses(t *testing.T) { + tests := []struct { + name string + targetStatus DisbursementStatus + expectedSourceStatuses []DisbursementStatus + }{ + { + name: "Draft", + targetStatus: DraftDisbursementStatus, + expectedSourceStatuses: []DisbursementStatus{}, + }, + { + name: "Ready", + targetStatus: ReadyDisbursementStatus, + expectedSourceStatuses: []DisbursementStatus{DraftDisbursementStatus, ReadyDisbursementStatus}, + }, + { + name: "Started", + targetStatus: StartedDisbursementStatus, + expectedSourceStatuses: []DisbursementStatus{ReadyDisbursementStatus, PausedDisbursementStatus}, + }, + { + name: "Paused", + targetStatus: PausedDisbursementStatus, + expectedSourceStatuses: []DisbursementStatus{StartedDisbursementStatus}, + }, + { + name: "Completed", + targetStatus: CompletedDisbursementStatus, + expectedSourceStatuses: []DisbursementStatus{StartedDisbursementStatus}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expectedSourceStatuses, tt.targetStatus.SourceStatuses()) + }) + } +} + +func Test_DisbursementStatus_DisbursementStatuses(t *testing.T) { + expectedStatuses := []DisbursementStatus{DraftDisbursementStatus, ReadyDisbursementStatus, StartedDisbursementStatus, PausedDisbursementStatus, CompletedDisbursementStatus} + require.Equal(t, expectedStatuses, DisbursementStatuses()) +} diff --git a/internal/data/disbursements_test.go b/internal/data/disbursements_test.go new file mode 100644 index 000000000..bb6881b07 --- /dev/null +++ b/internal/data/disbursements_test.go @@ -0,0 +1,645 @@ +package data + +import ( + "context" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_DisbursementModelInsert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := Disbursement{ + Name: "disbursement1", + Status: DraftDisbursementStatus, + StatusHistory: []DisbursementStatusHistoryEntry{ + { + Status: DraftDisbursementStatus, + UserID: "user1", + }, + }, + Asset: asset, + Country: country, + Wallet: wallet, + } + + t.Run("returns error when disbursement already exists is not found", func(t *testing.T) { + _, err := disbursementModel.Insert(ctx, &disbursement) + require.NoError(t, err) + _, err = disbursementModel.Insert(ctx, &disbursement) + require.Error(t, err) + require.Equal(t, ErrRecordAlreadyExists, err) + }) + + t.Run("insert disbursement successfully", func(t *testing.T) { + disbursement.Name = "disbursement2" + id, err := disbursementModel.Insert(ctx, &disbursement) + require.NoError(t, err) + require.NotNil(t, id) + + actual, err := disbursementModel.Get(ctx, dbConnectionPool, id) + require.NoError(t, err) + + assert.Equal(t, "disbursement2", actual.Name) + assert.Equal(t, DraftDisbursementStatus, actual.Status) + assert.Equal(t, asset, actual.Asset) + assert.Equal(t, country, actual.Country) + assert.Equal(t, wallet, actual.Wallet) + assert.Equal(t, 1, len(actual.StatusHistory)) + assert.Equal(t, DraftDisbursementStatus, actual.StatusHistory[0].Status) + assert.Equal(t, "user1", actual.StatusHistory[0].UserID) + }) +} + +func Test_DisbursementModelCount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := Disbursement{ + Status: DraftDisbursementStatus, + StatusHistory: []DisbursementStatusHistoryEntry{ + { + Status: DraftDisbursementStatus, + UserID: "user1", + }, + }, + Asset: asset, + Country: country, + Wallet: wallet, + } + + t.Run("returns 0 when no disbursements exist", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + count, err := disbursementModel.Count(ctx, dbConnectionPool, &QueryParams{}) + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + + t.Run("returns count of disbursements", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + count, err := disbursementModel.Count(ctx, dbConnectionPool, &QueryParams{}) + require.NoError(t, err) + assert.Equal(t, 2, count) + }) + + t.Run("returns count of disbursements", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + count, err := disbursementModel.Count(ctx, dbConnectionPool, &QueryParams{Query: "2"}) + require.NoError(t, err) + assert.Equal(t, 1, count) + }) +} + +func Test_DisbursementModelGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := Disbursement{ + Name: "disbursement1", + Status: DraftDisbursementStatus, + StatusHistory: []DisbursementStatusHistoryEntry{ + { + Status: DraftDisbursementStatus, + UserID: "user1", + }, + }, + Asset: asset, + Country: country, + Wallet: wallet, + } + + t.Run("returns error when disbursement does not exist", func(t *testing.T) { + _, err := disbursementModel.Get(ctx, dbConnectionPool, "invalid") + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns disbursement successfully", func(t *testing.T) { + expected := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + actual, err := disbursementModel.Get(ctx, dbConnectionPool, expected.ID) + require.NoError(t, err) + + assert.Equal(t, *expected, *actual) + }) +} + +func Test_DisbursementModelGetByName(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := Disbursement{ + Name: "disbursement1", + Status: DraftDisbursementStatus, + StatusHistory: []DisbursementStatusHistoryEntry{ + { + Status: DraftDisbursementStatus, + UserID: "user1", + }, + }, + Asset: asset, + Country: country, + Wallet: wallet, + } + + t.Run("returns error when disbursement does not exist", func(t *testing.T) { + _, err := disbursementModel.GetByName(ctx, dbConnectionPool, "invalid") + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns disbursement get by name successfully", func(t *testing.T) { + expected := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + actual, err := disbursementModel.GetByName(ctx, dbConnectionPool, expected.Name) + require.NoError(t, err) + + assert.Equal(t, *expected, *actual) + }) +} + +func Test_DisbursementModelGetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + paymentModel := PaymentModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := Disbursement{ + Status: DraftDisbursementStatus, + StatusHistory: []DisbursementStatusHistoryEntry{ + { + Status: DraftDisbursementStatus, + UserID: "user1", + }, + }, + Asset: asset, + Country: country, + Wallet: wallet, + } + + t.Run("returns empty list when no disbursements exist", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + disbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}) + require.NoError(t, err) + assert.Equal(t, 0, len(disbursements)) + }) + + t.Run("returns disbursements successfully", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + expected1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}) + require.NoError(t, err) + assert.Equal(t, 2, len(actualDisbursements)) + assert.Equal(t, []*Disbursement{expected1, expected2}, actualDisbursements) + }) + + t.Run("returns disbursements successfully with limit", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + expected1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 1, PageLimit: 1}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualDisbursements)) + assert.Equal(t, []*Disbursement{expected1}, actualDisbursements) + }) + + t.Run("returns disbursements successfully with offset", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Page: 2, PageLimit: 1}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualDisbursements)) + assert.Equal(t, []*Disbursement{expected2}, actualDisbursements) + }) + + t.Run("returns disbursements successfully with order", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + expected1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{SortBy: SortFieldName, SortOrder: SortOrderDESC}) + require.NoError(t, err) + assert.Equal(t, 2, len(actualDisbursements)) + assert.Equal(t, []*Disbursement{expected2, expected1}, actualDisbursements) + }) + + t.Run("returns disbursements successfully with filter", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + disbursement.Status = DraftDisbursementStatus + expected1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + disbursement.Status = CompletedDisbursementStatus + CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + filters := map[FilterKey]interface{}{ + FilterKeyStatus: []DisbursementStatus{DraftDisbursementStatus}, + } + actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualDisbursements)) + assert.Equal(t, []*Disbursement{expected1}, actualDisbursements) + }) + + t.Run("returns disbursements successfully with statuses parameter ", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement1" + disbursement.Status = DraftDisbursementStatus + disbursement.CreatedAt = time.Date(2023, 1, 30, 0, 0, 0, 0, time.UTC) + expected1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + disbursement.Name = "disbursement2" + disbursement.Status = CompletedDisbursementStatus + disbursement.CreatedAt = time.Date(2023, 3, 30, 0, 0, 0, 0, time.UTC) + expected2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + filters := map[FilterKey]interface{}{ + FilterKeyStatus: []DisbursementStatus{DraftDisbursementStatus, CompletedDisbursementStatus}, + } + actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{Filters: filters, SortBy: SortFieldCreatedAt, SortOrder: SortOrderDESC}) + + require.NoError(t, err) + assert.Equal(t, 2, len(actualDisbursements)) + assert.Equal(t, []*Disbursement{expected2, expected1}, actualDisbursements) + }) + t.Run("returns disbursements successfully with stats", func(t *testing.T) { + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + expectedDisbursement := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + ReceiverWallet: receiverWallet, + Disbursement: expectedDisbursement, + Asset: *asset, + Amount: "100", + Status: SuccessPaymentStatus, + }) + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + ReceiverWallet: receiverWallet, + Disbursement: expectedDisbursement, + Asset: *asset, + Amount: "150.05", + Status: DraftPaymentStatus, + }) + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + ReceiverWallet: receiverWallet, + Disbursement: expectedDisbursement, + Asset: *asset, + Amount: "020.50", + Status: FailedPaymentStatus, + }) + + expectedStats := &DisbursementStats{} + expectedStats.TotalPayments = 3 + expectedStats.SuccessfulPayments = 1 + expectedStats.FailedPayments = 1 + expectedStats.RemainingPayments = 1 + expectedStats.TotalAmount = "270.55" + expectedStats.AmountDisbursed = "100.00" + expectedStats.AverageAmount = "90.18" + + expectedDisbursement.DisbursementStats = expectedStats + + actualDisbursements, err := disbursementModel.GetAll(ctx, dbConnectionPool, &QueryParams{}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualDisbursements)) + assert.Equal(t, []*Disbursement{expectedDisbursement}, actualDisbursements) + }) +} + +func Test_DisbursementModel_Update(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := &DisbursementModel{dbConnectionPool: dbConnectionPool} + + disbursement := CreateDisbursementFixture(t, ctx, dbConnectionPool, &DisbursementModel{dbConnectionPool: dbConnectionPool}, &Disbursement{ + Name: "disbursement1", + }) + + disbursementFileContent := CreateInstructionsFixture(t, []*DisbursementInstruction{ + {"1234567890", "1", "123.12", "1995-02-20"}, + {"0987654321", "2", "321", "1974-07-19"}, + {"0987654321", "3", "321", "1974-07-19"}, + }) + + t.Run("update instructions", func(t *testing.T) { + err := disbursementModel.Update(ctx, &DisbursementUpdate{ + ID: disbursement.ID, + FileContent: disbursementFileContent, + FileName: "instructions.csv", + }) + require.NoError(t, err) + actual, err := disbursementModel.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + require.Equal(t, "instructions.csv", actual.FileName) + require.NotEmpty(t, actual.FileContent) + require.Equal(t, disbursementFileContent, actual.FileContent) + }) + + t.Run("no disbursement ID in update", func(t *testing.T) { + err := disbursementModel.Update(ctx, &DisbursementUpdate{ + FileContent: disbursementFileContent, + FileName: "instructions.csv", + }) + require.ErrorContains(t, err, "disbursement ID is required") + }) + + t.Run("no file name in update", func(t *testing.T) { + err := disbursementModel.Update(ctx, &DisbursementUpdate{ + FileContent: disbursementFileContent, + ID: disbursement.ID, + }) + require.ErrorContains(t, err, "file name is required") + }) + + t.Run("no file content in update", func(t *testing.T) { + err := disbursementModel.Update(ctx, &DisbursementUpdate{ + FileName: "instructions.csv", + ID: disbursement.ID, + }) + require.ErrorContains(t, err, "file content is required") + }) + + t.Run("empty file content in update", func(t *testing.T) { + err := disbursementModel.Update(ctx, &DisbursementUpdate{ + FileName: "instructions.csv", + ID: disbursement.ID, + FileContent: []byte{}, + }) + require.ErrorContains(t, err, "file content is required") + }) + + t.Run("wrong disbursement ID", func(t *testing.T) { + err := disbursementModel.Update(ctx, &DisbursementUpdate{ + FileName: "instructions.csv", + ID: "9e0ff65f-f6e9-46e9-bf03-dc46723e3bfb", + FileContent: disbursementFileContent, + }) + require.ErrorContains(t, err, "disbursement 9e0ff65f-f6e9-46e9-bf03-dc46723e3bfb was not updated") + }) +} + +func Test_DisbursementModel_CompleteDisbursements(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + + models, err := NewModels(dbConnectionPool) + require.NoError(t, err) + + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + + country := CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://") + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, ReadyReceiversWalletStatus) + + t.Run("does not complete not started disbursement", func(t *testing.T) { + readyDisbursement := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Name: "disbursement ready", + Status: ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + VerificationField: VerificationFieldDateOfBirth, + }) + + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id", + StellarOperationID: "operation-id", + Status: SuccessPaymentStatus, + Disbursement: readyDisbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = models.Disbursements.CompleteDisbursements(ctx, dbConnectionPool, []string{readyDisbursement.ID}) + require.NoError(t, err) + + readyDisbursement, err = models.Disbursements.Get(ctx, dbConnectionPool, readyDisbursement.ID) + require.NoError(t, err) + assert.Equal(t, ReadyDisbursementStatus, readyDisbursement.Status) + }) + + t.Run("does not complete started disbursement if not all payments are not completed", func(t *testing.T) { + startedDisbursement := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Name: "disbursement started", + Status: StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + VerificationField: VerificationFieldDateOfBirth, + }) + + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: SuccessPaymentStatus, + Disbursement: startedDisbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: FailedPaymentStatus, + Disbursement: startedDisbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = models.Disbursements.CompleteDisbursements(ctx, dbConnectionPool, []string{startedDisbursement.ID}) + require.NoError(t, err) + + startedDisbursement, err = models.Disbursements.Get(ctx, dbConnectionPool, startedDisbursement.ID) + require.NoError(t, err) + assert.Equal(t, StartedDisbursementStatus, startedDisbursement.Status) + }) + + t.Run("completes all started disbursements after payments are successful", func(t *testing.T) { + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Name: "disbursement 1", + Status: StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + VerificationField: VerificationFieldDateOfBirth, + }) + + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id", + StellarOperationID: "operation-id", + Status: SuccessPaymentStatus, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + disbursement2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Name: "disbursement 2", + Status: StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + VerificationField: VerificationFieldDateOfBirth, + }) + + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: SuccessPaymentStatus, + Disbursement: disbursement2, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + _ = CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: SuccessPaymentStatus, + Disbursement: disbursement2, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = models.Disbursements.CompleteDisbursements(ctx, dbConnectionPool, []string{disbursement1.ID, disbursement2.ID}) + require.NoError(t, err) + + disbursement1, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement1.ID) + require.NoError(t, err) + assert.Equal(t, CompletedDisbursementStatus, disbursement1.Status) + + disbursement2, err = models.Disbursements.Get(ctx, dbConnectionPool, disbursement2.ID) + require.NoError(t, err) + assert.Equal(t, CompletedDisbursementStatus, disbursement2.Status) + }) +} diff --git a/internal/data/fixtures.go b/internal/data/fixtures.go new file mode 100644 index 000000000..785f0610e --- /dev/null +++ b/internal/data/fixtures.go @@ -0,0 +1,663 @@ +package data + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/csv" + "fmt" + "image" + "image/color" + "math/big" + "testing" + "time" + + "github.com/lib/pq" + "github.com/stellar/go/keypair" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + + "github.com/stretchr/testify/require" +) + +const ( + FixtureCountryUSA = "USA" + FixtureCountryUKR = "UKR" + FixtureAssetUSDC = "USDC" +) + +func CreateAssetFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, code, issuer string) *Asset { + issuerAddress := issuer + + if issuerAddress == "" { + issuer, err := utils.RandomString(56) + require.NoError(t, err) + issuerAddress = issuer + } + + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + RETURNING + id, created_at, updated_at + ` + + asset := &Asset{ + Code: code, + Issuer: issuerAddress, + } + + err := sqlExec.QueryRowxContext(ctx, query, asset.Code, asset.Issuer).Scan(&asset.ID, &asset.CreatedAt, &asset.UpdatedAt) + require.NoError(t, err) + + return asset +} + +func GetAssetFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, code string) *Asset { + const query = ` + SELECT + * + FROM + assets a + WHERE + a.code = $1 + ` + + asset := &Asset{} + err := sqlExec.GetContext(ctx, asset, query, code) + require.NoError(t, err) + + return asset +} + +// DeleteAllAssetFixtures deletes all assets in the database +func DeleteAllAssetFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM assets" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +// ClearAndCreateAssetFixtures deletes all assets in the database then creates new assets for testing +func ClearAndCreateAssetFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) []Asset { + DeleteAllAssetFixtures(t, ctx, sqlExec) + expected := []Asset{ + *CreateAssetFixture(t, ctx, sqlExec, "EURT", "GA62MH5RDXFWAIWHQEFNMO2SVDDCQLWOO3GO36VQB5LHUXL22DQ6IQAU"), + *CreateAssetFixture(t, ctx, sqlExec, "USDC", "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE"), + } + return expected +} + +func CreateDefaultWalletFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) *Wallet { + return CreateWalletFixture(t, ctx, sqlExec, "Demo Wallet", + "https://demo-wallet.stellar.org", + "https://demo-wallet.stellar.org", + "demo-wallet-server.stellar.org") +} + +func CreateWalletFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, name, homepage, sep10ClientDomain, deepLinkSchema string) *Wallet { + const query = ` + INSERT INTO wallets + (name, homepage, sep_10_client_domain, deep_link_schema) + VALUES + ($1, $2, $3, $4) + ON CONFLICT DO NOTHING + RETURNING + id, created_at, updated_at + + ` + + _, err := sqlExec.ExecContext(ctx, query, name, homepage, sep10ClientDomain, deepLinkSchema) + require.NoError(t, err) + + return GetWalletFixture(t, ctx, sqlExec, name) +} + +func GetWalletFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, name string) *Wallet { + const query = ` + SELECT + * + FROM + wallets w + WHERE + w.name = $1 + ` + + wallet := &Wallet{} + err := sqlExec.GetContext(ctx, wallet, query, name) + require.NoError(t, err) + + return wallet +} + +// DeleteAllWalletFixtures deletes all wallets in the database +func DeleteAllWalletFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM wallets" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +// ClearAndCreateWalletFixtures deletes all wallets in the database then creates new wallets for testing +func ClearAndCreateWalletFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) []Wallet { + DeleteAllWalletFixtures(t, ctx, sqlExec) + expected := []Wallet{ + *CreateWalletFixture(t, ctx, sqlExec, "BOSS Money", "https://www.walletbyboss.com", "www.walletbyboss.com", "https://www.walletbyboss.com"), + *CreateWalletFixture(t, ctx, sqlExec, "Vibrant Assist", "https://vibrantapp.com", "vibrantapp.com", "vibrantapp://"), + } + return expected +} + +func GetCountryFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, code string) *Country { + const query = ` + SELECT + * + FROM + countries + WHERE + code = $1 + ` + + country := &Country{} + err := sqlExec.GetContext(ctx, country, query, code) + require.NoError(t, err) + + return country +} + +func CreateCountryFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, code, name string) *Country { + const query = ` + WITH create_country AS ( + INSERT INTO countries + (code, name) + VALUES + ($1, $2) + ON CONFLICT DO NOTHING + RETURNING * + ) + SELECT created_at, updated_at FROM create_country + UNION ALL + SELECT created_at, updated_at FROM countries WHERE code = $1 AND name = $2 + ` + + country := &Country{ + Code: code, + Name: name, + } + + err := sqlExec.QueryRowxContext(ctx, query, code, name).Scan(&country.CreatedAt, &country.UpdatedAt) + require.NoError(t, err) + + return country +} + +// DeleteAllCountryFixtures deletes all countries in the database +func DeleteAllCountryFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM countries" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +// ClearAndCreateCountryFixtures deletes all countries in the database then creates new countries for testing +func ClearAndCreateCountryFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) []Country { + DeleteAllCountryFixtures(t, ctx, sqlExec) + expected := []Country{ + *CreateCountryFixture(t, ctx, sqlExec, "BRA", "Brazil"), + *CreateCountryFixture(t, ctx, sqlExec, "UKR", "Ukraine"), + } + return expected +} + +func CreateReceiverFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, r *Receiver) *Receiver { + randomSuffix, err := utils.RandomString(5) + require.NoError(t, err) + + if r.Email == nil { + email := fmt.Sprintf("email%s@randomemail.com", randomSuffix) + r.Email = &email + } + + if r.PhoneNumber == "" { + r.PhoneNumber = "+141555" + randomSuffix + } + + if r.ExternalID == "" { + r.ExternalID, err = utils.RandomString(56) + require.NoError(t, err) + } + + if r.CreatedAt == nil { + now := time.Now() + r.CreatedAt = &now + } + + if r.UpdatedAt == nil { + now := time.Now() + r.UpdatedAt = &now + } + + const query = ` + INSERT INTO receivers + (email, phone_number, external_id, created_at, updated_at) + VALUES + ($1, $2, $3, $4, $5) + RETURNING + id, email, phone_number, external_id, created_at, updated_at + ` + + var receiver Receiver + err = sqlExec.QueryRowxContext(ctx, query, r.Email, r.PhoneNumber, r.ExternalID, r.CreatedAt, r.UpdatedAt).Scan( + &receiver.ID, + &receiver.Email, + &receiver.PhoneNumber, + &receiver.ExternalID, + &receiver.CreatedAt, + &receiver.UpdatedAt, + ) + require.NoError(t, err) + + return &receiver +} + +func DeleteAllReceiversFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM receivers" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +func CreateReceiverVerificationFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, insert ReceiverVerificationInsert) *ReceiverVerification { + const query = ` + INSERT INTO receiver_verifications + (receiver_id, verification_field, hashed_value) + VALUES + ($1, $2, $3) + RETURNING + receiver_id, verification_field, hashed_value, attempts, created_at, confirmed_at, updated_at, failed_at + ` + + var verification ReceiverVerification + verificationValue, err := HashVerificationValue(insert.VerificationValue) + require.NoError(t, err) + + err = sqlExec.GetContext(ctx, &verification, query, insert.ReceiverID, insert.VerificationField, verificationValue) + require.NoError(t, err) + + return &verification +} + +func DeleteAllReceiverVerificationFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM receiver_verifications" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +func CreateReceiverWalletFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, receiverID, walletID string, status ReceiversWalletStatus) *ReceiverWallet { + kp, err := keypair.Random() + require.NoError(t, err) + stellarAddress := kp.Address() + + randNumber, err := rand.Int(rand.Reader, big.NewInt(90000)) + require.NoError(t, err) + + stellarMemo := fmt.Sprint(randNumber.Int64() + 10000) + stellarMemoType := "id" + + const query = ` + WITH inserted_receiver_wallet AS ( + INSERT INTO receiver_wallets + (receiver_id, wallet_id, stellar_address, stellar_memo, stellar_memo_type, status) + VALUES + ($1, $2, $3, $4, $5, $6) + RETURNING + id, receiver_id, wallet_id, stellar_address, stellar_memo, stellar_memo_type, status, status_history, created_at, updated_at + ) + SELECT + rw.id, rw.stellar_address, rw.stellar_memo, rw.stellar_memo_type, rw.status, rw.status_history, rw.created_at, rw.updated_at, + r.id, r.email, r.phone_number, r.external_id, r.created_at, r.updated_at, + w.id, w.name, w.homepage, w.deep_link_schema, w.created_at, w.updated_at + FROM + inserted_receiver_wallet AS rw + JOIN receivers AS r ON rw.receiver_id = r.id + JOIN wallets AS w ON rw.wallet_id = w.id + ` + + var statusHistoryJSON pq.ByteaArray + var receiverWallet ReceiverWallet + err = sqlExec.QueryRowxContext(ctx, query, receiverID, walletID, stellarAddress, stellarMemo, stellarMemoType, status).Scan( + &receiverWallet.ID, + &receiverWallet.StellarAddress, + &receiverWallet.StellarMemo, + &receiverWallet.StellarMemoType, + &receiverWallet.Status, + &statusHistoryJSON, + &receiverWallet.CreatedAt, + &receiverWallet.UpdatedAt, + &receiverWallet.Receiver.ID, + &receiverWallet.Receiver.Email, + &receiverWallet.Receiver.PhoneNumber, + &receiverWallet.Receiver.ExternalID, + &receiverWallet.Receiver.CreatedAt, + &receiverWallet.Receiver.UpdatedAt, + &receiverWallet.Wallet.ID, + &receiverWallet.Wallet.Name, + &receiverWallet.Wallet.Homepage, + &receiverWallet.Wallet.DeepLinkSchema, + &receiverWallet.Wallet.CreatedAt, + &receiverWallet.Wallet.UpdatedAt, + ) + require.NoError(t, err) + + err = receiverWallet.statusHistoryFromByteArray(statusHistoryJSON) + require.NoError(t, err) + + return &receiverWallet +} + +func DeleteAllReceiverWalletsFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = ` + DELETE FROM receiver_wallets + ` + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +func CreatePaymentFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, model *PaymentModel, p *Payment) *Payment { + if p.StatusHistory == nil { + p.StatusHistory = []PaymentStatusHistoryEntry{{ + Timestamp: time.Now(), + Status: p.Status, + StatusMessage: "", + }} + } + + if p.CreatedAt.IsZero() { + p.CreatedAt = time.Now() + } + + if p.UpdatedAt.IsZero() { + p.UpdatedAt = time.Now() + } + + const query = ` + INSERT INTO payments + (receiver_id, disbursement_id, receiver_wallet_id, asset_id, amount, status, status_history, + stellar_transaction_id, stellar_operation_id, created_at, updated_at) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + RETURNING + id + ` + var newId string + err := sqlExec.GetContext(ctx, &newId, query, + p.ReceiverWallet.Receiver.ID, + p.Disbursement.ID, + p.ReceiverWallet.ID, + p.Asset.ID, + p.Amount, + p.Status, + p.StatusHistory, + p.StellarTransactionID, + p.StellarOperationID, + p.CreatedAt, + p.UpdatedAt, + ) + require.NoError(t, err) + + // get payment + payment, err := model.Get(ctx, newId, sqlExec) + require.NoError(t, err) + return payment +} + +func DeleteAllPaymentsFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM payments" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +func CreateDisbursementFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, model *DisbursementModel, d *Disbursement) *Disbursement { + if d == nil { + d = &Disbursement{} + } + if d.Name == "" { + randomName, err := utils.RandomString(10) + require.NoError(t, err) + d.Name = randomName + } + if d.Status == "" { + d.Status = DraftDisbursementStatus + } + if d.Wallet == nil { + d.Wallet = CreateDefaultWalletFixture(t, ctx, sqlExec) + } + if d.Asset == nil { + d.Asset = GetAssetFixture(t, ctx, sqlExec, FixtureAssetUSDC) + } + if d.Country == nil { + d.Country = GetCountryFixture(t, ctx, sqlExec, FixtureCountryUKR) + } + // insert disbursement + if d.StatusHistory == nil { + d.StatusHistory = []DisbursementStatusHistoryEntry{{ + Timestamp: time.Now(), + Status: d.Status, + }} + } + id, err := model.Insert(ctx, d) + require.NoError(t, err) + + // update created_at + const query = ` + UPDATE disbursements + SET created_at = $1 + WHERE id = $2 + ` + _, err = sqlExec.ExecContext(ctx, query, d.CreatedAt, id) + require.NoError(t, err) + + // get disbursement + disbursement, err := model.Get(ctx, model.dbConnectionPool, id) + require.NoError(t, err) + return disbursement +} + +func UpdateDisbursementInstructionsFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, disbursementID, fileName string, instructions []*DisbursementInstruction) { + fileContent := CreateInstructionsFixture(t, instructions) + + const query = ` + UPDATE disbursements + SET file_name = $1, file_content = $2 + WHERE id = $3 + ` + _, err := sqlExec.ExecContext(ctx, query, fileName, fileContent, disbursementID) + require.NoError(t, err) +} + +func CreateInstructionsFixture(t *testing.T, instructions []*DisbursementInstruction) []byte { + // phone,id,amount,verification + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + + // write header + outerErr := writer.Write([]string{"phone", "id", "amount", "verification_value"}) + require.NoError(t, outerErr) + + // write instructions + for _, instruction := range instructions { + record := []string{instruction.Phone, instruction.ID, instruction.Amount, instruction.VerificationValue} + err := writer.Write(record) + require.NoError(t, err) + } + writer.Flush() + return buf.Bytes() +} + +func CreateDraftDisbursementFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, model *DisbursementModel, insert Disbursement) *Disbursement { + if insert.StatusHistory == nil { + insert.StatusHistory = []DisbursementStatusHistoryEntry{{ + Timestamp: time.Now(), + Status: DraftDisbursementStatus, + UserID: "user1", + }} + } + + if insert.Status == "" { + insert.Status = DraftDisbursementStatus + } + + id, err := model.Insert(ctx, &insert) + require.NoError(t, err) + + // get disbursement + disbursement, err := model.Get(ctx, sqlExec, id) + require.NoError(t, err) + return disbursement +} + +func DeleteAllDisbursementFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM disbursements" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +func CreateMessageFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, m *Message) *Message { + if m.TextEncrypted == "" { + m.TextEncrypted = "text encrypted" + } + + if m.TitleEncrypted == "" { + m.TitleEncrypted = "title encrypted" + } + + if m.CreatedAt.IsZero() { + m.CreatedAt = time.Now().UTC() + } + + const query = ` + INSERT INTO messages + ( + type, asset_id, receiver_id, wallet_id, receiver_wallet_id, + text_encrypted, title_encrypted, status, created_at, updated_at + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING + id, status_history + ` + + err := sqlExec.QueryRowxContext(ctx, query, m.Type, m.AssetID, m.ReceiverID, m.WalletID, m.ReceiverWalletID, m.TextEncrypted, m.TitleEncrypted, m.Status, m.CreatedAt, m.UpdatedAt).Scan( + &m.ID, + &m.StatusHistory, + ) + require.NoError(t, err) + + return m +} + +// EnableDisbursementApproval enables disbursement workflow approval for the given organization. +func EnableDisbursementApproval(t *testing.T, ctx context.Context, orgModel *OrganizationModel) { + isApprovalRequired := true + err := orgModel.Update(ctx, &OrganizationUpdate{IsApprovalRequired: &isApprovalRequired}) + require.NoError(t, err) +} + +// DisableDisbursementApproval disables disbursement workflow approval for the given organization. +func DisableDisbursementApproval(t *testing.T, ctx context.Context, orgModel *OrganizationModel) { + isApprovalRequired := false + err := orgModel.Update(ctx, &OrganizationUpdate{IsApprovalRequired: &isApprovalRequired}) + require.NoError(t, err) +} + +func DeleteAllMessagesFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = ` + DELETE FROM messages + ` + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +type ImageSize int + +const ( + ImageSizeSmall ImageSize = iota + ImageSizeMedium + ImageSizeLarge +) + +/* +CreateMockImage creates an RGBA image with the given proportion and size. +The size is defined by how many different colors are drawn in the image, +so the compression format (jpeg or png) will generate a larger file since +the image will have more complexity. Note: Depending on the compression format +the image size may vary. + +Example creating a file: + + img := CreateMockImage(t, 3840, 2160, ImageSizeLarge) + f, err := os.Create("image.png") + require.NoError(t, err) + err = jpeg.Encode(f, img, &jpeg.Options{Quality: jpeg.DefaultQuality} + require.NoError(t, err) + +Example in memory image: + + img := CreateMockImage(t, 1920, 1080, ImageSizeMedium) + buf := new(bytes.Buffer) + err = png.Encode(buf, img) + require.NoError(t, err) + fmt.Println(img.Bytes()) +*/ +func CreateMockImage(t *testing.T, width, height int, size ImageSize) image.Image { + imgRect := image.Rect(0, 0, width, height) + img := image.NewRGBA(imgRect) + + bigInt := big.NewInt(255) + + // sets a random color for every pixel. It increase the compression complexity. + largeImageColor := func() color.Color { + r, err := rand.Int(rand.Reader, bigInt) + require.NoError(t, err) + + g, err := rand.Int(rand.Reader, bigInt) + require.NoError(t, err) + + b, err := rand.Int(rand.Reader, bigInt) + require.NoError(t, err) + + return color.RGBA{uint8(r.Int64()), uint8(g.Int64()), uint8(b.Int64()), 255} + } + + // sets the same color for each line. It's less complex than the largeImageColor. + mediumImageColor := func() color.Color { + n, err := rand.Int(rand.Reader, bigInt) + require.NoError(t, err) + + return color.RGBA{uint8(n.Int64()), uint8(n.Int64()), uint8(n.Int64()), 255} + } + + // sets the same color for the entire image. No complexity. + smallImageColor := func() color.Color { + // returns the cyan color + return color.RGBA{100, 200, 200, 0xff} + } + + var c color.Color + for x := 0; x < width; x++ { + if size == ImageSizeMedium { + c = mediumImageColor() + } + + for y := 0; y < height; y++ { + switch size { + case ImageSizeSmall: + c = smallImageColor() + case ImageSizeLarge: + c = largeImageColor() + } + + img.Set(x, y, c) + } + } + + return img +} diff --git a/internal/data/fixtures_test.go b/internal/data/fixtures_test.go new file mode 100644 index 000000000..f48136f00 --- /dev/null +++ b/internal/data/fixtures_test.go @@ -0,0 +1,133 @@ +package data + +import ( + "context" + "strings" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/require" +) + +func Test_CreateReceiverFixture(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + // Create a random receiver + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + require.Len(t, receiver.ID, 36) + require.NotEmpty(t, receiver.Email) + require.NotEmpty(t, receiver.PhoneNumber) + require.NotEmpty(t, receiver.ExternalID) + require.NotEmpty(t, receiver.CreatedAt) + require.NotEmpty(t, receiver.UpdatedAt) +} + +func Test_CreateReceiverWalletFixture(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + // Create a random receiver wallet + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "My Wallet", "https://mywallet.test.com/", "mywallet.test.com", "mtwallet://") + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + rw := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + + // Check receiver wallet + require.Len(t, rw.ID, 36) + require.NotEmpty(t, rw.StellarAddress) + require.NotEmpty(t, rw.StellarMemo) + require.NotEmpty(t, rw.StellarMemoType) + require.Equal(t, DraftReceiversWalletStatus, rw.Status) + require.Len(t, rw.StatusHistory, 1) + require.NotEmpty(t, rw.StatusHistory[0].Timestamp) + require.Equal(t, DraftReceiversWalletStatus, rw.StatusHistory[0].Status) + require.NotEmpty(t, rw.CreatedAt) + require.NotEmpty(t, rw.UpdatedAt) + + // Check receiver + require.Len(t, rw.Receiver.ID, 36) + require.Equal(t, receiver.ID, rw.Receiver.ID) + require.NotEmpty(t, rw.Receiver.Email) + require.NotEmpty(t, rw.Receiver.PhoneNumber) + require.NotEmpty(t, rw.Receiver.ExternalID) + require.NotEmpty(t, rw.Receiver.CreatedAt) + require.NotEmpty(t, rw.Receiver.UpdatedAt) + + // Check wallet + require.Len(t, rw.Wallet.ID, 36) + require.Equal(t, wallet.ID, rw.Wallet.ID) + require.NotEmpty(t, rw.Wallet.Name) + require.NotEmpty(t, rw.Wallet.Homepage) + require.NotEmpty(t, rw.Wallet.DeepLinkSchema) + require.NotEmpty(t, rw.Wallet.CreatedAt) + require.NotEmpty(t, rw.Wallet.UpdatedAt) +} + +func Test_Fixtures_CreateInstructionsFixture(t *testing.T) { + t.Run("header only for nil instructions", func(t *testing.T) { + fileContent := CreateInstructionsFixture(t, nil) + lines := strings.Split(string(fileContent), "\n") + require.Equal(t, "phone,id,amount,verification_value", lines[0]) + }) + + t.Run("header only for empty instructions", func(t *testing.T) { + buf := CreateInstructionsFixture(t, []*DisbursementInstruction{}) + require.Equal(t, "phone,id,amount,verification_value\n", string(buf)) + }) + + t.Run("writes records correctly", func(t *testing.T) { + instructions := []*DisbursementInstruction{ + {"1234567890", "1", "123.12", "1995-02-20"}, + {"0987654321", "2", "321", "1974-07-19"}, + } + buf := CreateInstructionsFixture(t, instructions) + lines := strings.Split(string(buf), "\n") + require.Equal(t, "1234567890,1,123.12,1995-02-20", lines[1]) + require.Equal(t, "0987654321,2,321,1974-07-19", lines[2]) + }) +} + +func Test_Fixtures_UpdateDisbursementInstructionsFixture(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + + disbursementModel := &DisbursementModel{dbConnectionPool: dbConnectionPool} + + disbursement := CreateDisbursementFixture(t, ctx, dbConnectionPool, &DisbursementModel{dbConnectionPool: dbConnectionPool}, &Disbursement{ + Name: "disbursement1", + }) + + instructions := []*DisbursementInstruction{ + {"1234567890", "1", "123.12", "1995-02-20"}, + {"0987654321", "2", "321", "1974-07-19"}, + {"0987654321", "3", "321", "1974-07-19"}, + } + + t.Run("update instructions", func(t *testing.T) { + UpdateDisbursementInstructionsFixture(t, ctx, dbConnectionPool, disbursement.ID, "test.csv", instructions) + actual, err := disbursementModel.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + require.Equal(t, "test.csv", actual.FileName) + require.NotEmpty(t, actual.FileContent) + expected := CreateInstructionsFixture(t, instructions) + require.Equal(t, expected, actual.FileContent) + }) +} diff --git a/internal/data/messages.go b/internal/data/messages.go new file mode 100644 index 000000000..492b8d195 --- /dev/null +++ b/internal/data/messages.go @@ -0,0 +1,191 @@ +package data + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "time" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" +) + +type MessageStatus string + +var ( + PendingMessageStatus MessageStatus = "PENDING" + SuccessMessageStatus MessageStatus = "SUCCESS" + FailureMessageStatus MessageStatus = "FAILURE" +) + +type MessageModel struct { + dbConnectionPool db.DBConnectionPool +} + +type Message struct { + ID string `db:"id"` + Type message.MessengerType `db:"type"` + AssetID *string `db:"asset_id"` + ReceiverID string `db:"receiver_id"` + WalletID string `db:"wallet_id"` + ReceiverWalletID *string `db:"receiver_wallet_id"` + TextEncrypted string `db:"text_encrypted"` + TitleEncrypted string `db:"title_encrypted"` + Status MessageStatus `db:"status"` + StatusHistory MessageStatusHistory `db:"status_history"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +type MessageInsert struct { + Type message.MessengerType + AssetID *string + ReceiverID string + WalletID string + ReceiverWalletID *string + TextEncrypted string + TitleEncrypted string + Status MessageStatus + StatusHistory MessageStatusHistory +} + +type MessageStatusHistoryEntry struct { + StatusMessage *string `json:"status_message"` + Status MessageStatus `json:"status"` + Timestamp time.Time `json:"timestamp"` +} + +type MessageStatusHistory []MessageStatusHistoryEntry + +// Value implements the driver.Valuer interface. +func (msh MessageStatusHistory) Value() (driver.Value, error) { + var statusHistoryJSON []string + for _, sh := range msh { + shJSONBytes, err := json.Marshal(sh) + if err != nil { + return nil, fmt.Errorf("error converting status history to json for message: %w", err) + } + statusHistoryJSON = append(statusHistoryJSON, string(shJSONBytes)) + } + + return pq.Array(statusHistoryJSON).Value() +} + +// Scan implements the sql.Scanner interface. +func (msh *MessageStatusHistory) Scan(src interface{}) error { + var statusHistoryJSON []string + if err := pq.Array(&statusHistoryJSON).Scan(src); err != nil { + return fmt.Errorf("error scanning status history value: %w", err) + } + + for _, sh := range statusHistoryJSON { + var shEntry MessageStatusHistoryEntry + err := json.Unmarshal([]byte(sh), &shEntry) + if err != nil { + return fmt.Errorf("error unmarshaling status_history column: %w", err) + } + *msh = append(*msh, shEntry) + } + + return nil +} + +func (m *MessageModel) Insert(ctx context.Context, newMsg *MessageInsert) (*Message, error) { + const query = ` + INSERT INTO messages + ( + type, asset_id, receiver_id, wallet_id, receiver_wallet_id, + text_encrypted, title_encrypted, status, status_history + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING + * + ` + var msg Message + err := m.dbConnectionPool.GetContext(ctx, &msg, query, newMsg.Type, newMsg.AssetID, newMsg.ReceiverID, newMsg.WalletID, newMsg.ReceiverWalletID, newMsg.TextEncrypted, newMsg.TitleEncrypted, newMsg.Status, newMsg.StatusHistory) + if err != nil { + return nil, fmt.Errorf("error inserting message: %w", err) + } + + return &msg, nil +} + +func (m *MessageModel) BulkInsert(ctx context.Context, newMsgs []*MessageInsert) error { + var ( + types, receiverIDs, walletIDs pq.StringArray + encryptedTexts, encryptedTitles, statuses pq.StringArray + assetIDs, receiverWalletIDs []sql.NullString + ) + + for _, newMsg := range newMsgs { + types = append(types, string(newMsg.Type)) + + assetID := "" + if newMsg.AssetID != nil { + assetID = *newMsg.AssetID + } + assetIDs = append(assetIDs, sql.NullString{ + String: assetID, + Valid: (newMsg.AssetID != nil && *newMsg.AssetID != ""), + }) + + receiverIDs = append(receiverIDs, newMsg.ReceiverID) + walletIDs = append(walletIDs, newMsg.WalletID) + + receiverWalletID := "" + if newMsg.ReceiverWalletID != nil { + receiverWalletID = *newMsg.ReceiverWalletID + } + receiverWalletIDs = append(receiverWalletIDs, sql.NullString{ + String: receiverWalletID, + Valid: (newMsg.ReceiverWalletID != nil && *newMsg.ReceiverWalletID != ""), + }) + + encryptedTexts = append(encryptedTexts, newMsg.TextEncrypted) + encryptedTitles = append(encryptedTitles, newMsg.TitleEncrypted) + statuses = append(statuses, string(newMsg.Status)) + } + + return db.RunInTransaction(ctx, m.dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + const insertQuery = ` + INSERT INTO messages + ( + type, asset_id, receiver_id, wallet_id, receiver_wallet_id, + text_encrypted, title_encrypted, status + ) + SELECT + UNNEST($1::message_type[]) AS type, UNNEST($2::text[]) AS asset_id, UNNEST($3::text[]) AS receiver_id, UNNEST($4::text[]) AS wallet_id, + UNNEST($5::text[]) AS receiver_wallet_id, UNNEST($6::text[]) AS text_encrypted, UNNEST($7::text[]) AS title_encrypted, + UNNEST($8::message_status[]) AS status + RETURNING + id + ` + + var newMsgIDs []string + err := dbTx.SelectContext(ctx, &newMsgIDs, insertQuery, types, pq.Array(assetIDs), receiverIDs, walletIDs, pq.Array(receiverWalletIDs), encryptedTexts, encryptedTitles, statuses) + if err != nil { + return fmt.Errorf("error inserting messages in BulkInsert: %w", err) + } + + const updateQuery = ` + UPDATE + messages + SET + status_history = array_append(status_history, create_message_status_history(updated_at, status, NULL)), + updated_at = NOW() + WHERE + id = ANY($1::text[]) + AND status != 'PENDING' + ` + _, err = dbTx.ExecContext(ctx, updateQuery, pq.Array(newMsgIDs)) + if err != nil { + return fmt.Errorf("error update messages status history: %w", err) + } + + return nil + }) +} diff --git a/internal/data/messages_test.go b/internal/data/messages_test.go new file mode 100644 index 000000000..90288df42 --- /dev/null +++ b/internal/data/messages_test.go @@ -0,0 +1,172 @@ +package data + +import ( + "context" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_MessageModel_Insert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + t.Run("inserts a new message successfully", func(t *testing.T) { + mm := &MessageModel{dbConnectionPool: dbConnectionPool} + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + statusHistory := MessageStatusHistory{ + { + Status: PendingMessageStatus, + Timestamp: time.Now().UTC(), + }, + } + + msg, err := mm.Insert(ctx, &MessageInsert{ + Type: message.MessengerTypeTwilioSMS, + AssetID: &asset.ID, + ReceiverID: receiver.ID, + WalletID: wallet.ID, + TextEncrypted: "text encrypted", + TitleEncrypted: "title encrypted", + Status: PendingMessageStatus, + StatusHistory: statusHistory, + }) + require.NoError(t, err) + + assert.NotEmpty(t, msg.ID) + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, PendingMessageStatus, msg.Status) + assert.Equal(t, statusHistory, msg.StatusHistory) + assert.Equal(t, asset.ID, *msg.AssetID) + assert.Equal(t, receiver.ID, msg.ReceiverID) + assert.Equal(t, wallet.ID, msg.WalletID) + assert.Equal(t, "text encrypted", msg.TextEncrypted) + assert.Equal(t, "title encrypted", msg.TitleEncrypted) + assert.NotEmpty(t, msg.CreatedAt) + assert.NotEmpty(t, msg.UpdatedAt) + }) +} + +func Test_MessageModel_BulkInsert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + mm := &MessageModel{dbConnectionPool: dbConnectionPool} + + t.Run("inserts a new messages successfully", func(t *testing.T) { + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://") + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + + err := mm.BulkInsert(ctx, []*MessageInsert{ + { + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet.ID, + TextEncrypted: "text encrypted", + TitleEncrypted: "title encrypted", + Status: SuccessMessageStatus, + }, + { + Type: message.MessengerTypeTwilioSMS, + AssetID: &asset.ID, + ReceiverID: receiver.ID, + WalletID: wallet.ID, + ReceiverWalletID: nil, + TextEncrypted: "text encrypted", + TitleEncrypted: "title encrypted", + Status: PendingMessageStatus, + }, + { + Type: message.MessengerTypeTwilioSMS, + AssetID: &asset.ID, + ReceiverID: receiver.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet.ID, + TextEncrypted: "text encrypted", + TitleEncrypted: "title encrypted", + Status: FailureMessageStatus, + }, + }) + require.NoError(t, err) + + const q = ` + SELECT + id, type, asset_id, receiver_id, wallet_id, receiver_wallet_id, + text_encrypted, title_encrypted, status, status_history, + created_at, updated_at + FROM + messages + ORDER BY + status::text + ` + + var messages []Message + err = dbConnectionPool.SelectContext(ctx, &messages, q) + require.NoError(t, err) + + assert.Len(t, messages, 3) + + // Failure + assert.Equal(t, message.MessengerTypeTwilioSMS, messages[0].Type) + assert.Equal(t, asset.ID, *messages[0].AssetID) + assert.Equal(t, receiver.ID, messages[0].ReceiverID) + assert.Equal(t, wallet.ID, messages[0].WalletID) + assert.Equal(t, message.MessengerTypeTwilioSMS, messages[0].Type) + assert.Equal(t, receiverWallet.ID, *messages[0].ReceiverWalletID) + assert.Equal(t, "text encrypted", messages[0].TextEncrypted) + assert.Equal(t, "title encrypted", messages[0].TitleEncrypted) + assert.Equal(t, FailureMessageStatus, messages[0].Status) + assert.Len(t, messages[0].StatusHistory, 2) + assert.Equal(t, PendingMessageStatus, messages[0].StatusHistory[0].Status) + assert.Equal(t, FailureMessageStatus, messages[0].StatusHistory[1].Status) + + // Pending + assert.Equal(t, message.MessengerTypeTwilioSMS, messages[1].Type) + assert.Equal(t, asset.ID, *messages[1].AssetID) + assert.Equal(t, receiver.ID, messages[1].ReceiverID) + assert.Equal(t, wallet.ID, messages[1].WalletID) + assert.Equal(t, message.MessengerTypeTwilioSMS, messages[1].Type) + assert.Nil(t, messages[1].ReceiverWalletID) + assert.Equal(t, "text encrypted", messages[1].TextEncrypted) + assert.Equal(t, "title encrypted", messages[1].TitleEncrypted) + assert.Equal(t, PendingMessageStatus, messages[1].Status) + assert.Len(t, messages[1].StatusHistory, 1) + assert.Equal(t, PendingMessageStatus, messages[1].StatusHistory[0].Status) + + // Success + assert.Equal(t, message.MessengerTypeTwilioSMS, messages[2].Type) + assert.Nil(t, messages[2].AssetID) + assert.Equal(t, receiver.ID, messages[2].ReceiverID) + assert.Equal(t, wallet.ID, messages[2].WalletID) + assert.Equal(t, message.MessengerTypeTwilioSMS, messages[2].Type) + assert.Equal(t, receiverWallet.ID, *messages[2].ReceiverWalletID) + assert.Equal(t, "text encrypted", messages[2].TextEncrypted) + assert.Equal(t, "title encrypted", messages[2].TitleEncrypted) + assert.Equal(t, SuccessMessageStatus, messages[2].Status) + assert.Len(t, messages[2].StatusHistory, 2) + assert.Equal(t, PendingMessageStatus, messages[2].StatusHistory[0].Status) + assert.Equal(t, SuccessMessageStatus, messages[2].StatusHistory[1].Status) + }) +} diff --git a/internal/data/models.go b/internal/data/models.go new file mode 100644 index 000000000..e8946829c --- /dev/null +++ b/internal/data/models.go @@ -0,0 +1,51 @@ +package data + +import ( + "errors" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +var ( + ErrRecordNotFound = errors.New("record not found") + ErrRecordAlreadyExists = errors.New("record already exists") + ErrMismatchNumRowsAffected = errors.New("mismatch number of rows affected") + ErrMissingInput = errors.New("missing input") +) + +type Models struct { + Disbursements *DisbursementModel + Wallets *WalletModel + Countries *CountryModel + Assets *AssetModel + Organizations *OrganizationModel + Payment *PaymentModel + Receiver *ReceiverModel + DisbursementInstructions *DisbursementInstructionModel + ReceiverVerification *ReceiverVerificationModel + ReceiverWallet *ReceiverWalletModel + DisbursementReceivers *DisbursementReceiverModel + Message *MessageModel + DBConnectionPool db.DBConnectionPool +} + +func NewModels(dbConnectionPool db.DBConnectionPool) (*Models, error) { + if dbConnectionPool == nil { + return nil, errors.New("dbConnectionPool is required for NewModels") + } + return &Models{ + Disbursements: &DisbursementModel{dbConnectionPool: dbConnectionPool}, + Wallets: &WalletModel{dbConnectionPool: dbConnectionPool}, + Countries: &CountryModel{dbConnectionPool: dbConnectionPool}, + Assets: &AssetModel{dbConnectionPool: dbConnectionPool}, + Organizations: &OrganizationModel{dbConnectionPool: dbConnectionPool}, + Payment: &PaymentModel{dbConnectionPool: dbConnectionPool}, + Receiver: &ReceiverModel{}, + DisbursementInstructions: NewDisbursementInstructionModel(dbConnectionPool), + ReceiverVerification: &ReceiverVerificationModel{}, + ReceiverWallet: &ReceiverWalletModel{dbConnectionPool: dbConnectionPool}, + DisbursementReceivers: &DisbursementReceiverModel{dbConnectionPool: dbConnectionPool}, + Message: &MessageModel{dbConnectionPool: dbConnectionPool}, + DBConnectionPool: dbConnectionPool, + }, nil +} diff --git a/internal/data/models_test.go b/internal/data/models_test.go new file mode 100644 index 000000000..6f8b00830 --- /dev/null +++ b/internal/data/models_test.go @@ -0,0 +1,30 @@ +package data + +import ( + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/require" +) + +func Test_NewModels(t *testing.T) { + t.Run("returns error if the db connection pool is nil", func(t *testing.T) { + models, err := NewModels(nil) + require.Nil(t, models) + require.EqualError(t, err, "dbConnectionPool is required for NewModels") + }) + + t.Run("returns model successfully πŸŽ‰", func(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := NewModels(dbConnectionPool) + require.NoError(t, err) + require.NotNil(t, models) + }) +} diff --git a/internal/data/organizations.go b/internal/data/organizations.go new file mode 100644 index 000000000..8563dadcb --- /dev/null +++ b/internal/data/organizations.go @@ -0,0 +1,184 @@ +package data + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + "image" + "regexp" + + // Don't remove the `image/jpeg` and `image/png` packages import unless + // the `image` package is no longer necessary. + // It registers the `Decoders` to handle the image decoding - `image.Decode`. + // See https://pkg.go.dev/image#pkg-overview + _ "image/jpeg" + _ "image/png" + "strings" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type Organization struct { + ID string `json:"id" db:"id"` + Name string `json:"name" db:"name"` + StellarMainAddress string `json:"stellar_main_address" db:"stellar_main_address"` + TimezoneUTCOffset string `json:"timezone_utc_offset" db:"timezone_utc_offset"` + ArePaymentsEnabled bool `json:"are_payments_enabled" db:"are_payments_enabled"` + SMSRegistrationMessageTemplate string `json:"sms_registration_message_template" db:"sms_registration_message_template"` + Logo []byte `db:"logo"` + IsApprovalRequired bool `json:"is_approval_required" db:"is_approval_required"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type OrganizationUpdate struct { + Name string + Logo []byte + TimezoneUTCOffset string + IsApprovalRequired *bool +} + +type LogoType string + +const ( + PNGLogoType LogoType = "png" + JPEGLogoType LogoType = "jpeg" + + // tzRegexExpression validates the TimezoneUTCOffset value. It expects the following format: + // plus or minus symbol + two numbers + colon symbol + two numbers + // Example: + // +02:00 or -03:00 + // Any other value will be invalid. + tzRegexExpression string = `^(\+|-)\d{2}:\d{2}$` +) + +var tzRegex *regexp.Regexp + +func init() { + tzRegex = regexp.MustCompile(tzRegexExpression) +} + +func (lt LogoType) ToHTTPContentType() string { + return fmt.Sprintf("image/%s", lt) +} + +func (ou *OrganizationUpdate) validate() error { + if ou.Name == "" && len(ou.Logo) == 0 && ou.TimezoneUTCOffset == "" && ou.IsApprovalRequired == nil { + return fmt.Errorf("name, timezone UTC offset, approval workflow flag or logo is required") + } + + if len(ou.Logo) > 0 { + _, format, err := image.Decode(bytes.NewBuffer(ou.Logo)) + if err != nil { + return fmt.Errorf("error decoding image bytes: %w", err) + } + + if !strings.Contains(fmt.Sprintf("%s %s", PNGLogoType, JPEGLogoType), format) { + return fmt.Errorf("invalid image type provided. Expect %s or %s", PNGLogoType, JPEGLogoType) + } + } + + if ou.TimezoneUTCOffset != "" && !tzRegex.MatchString(ou.TimezoneUTCOffset) { + return fmt.Errorf("invalid timezone UTC offset format. Example: +02:00 or -03:00") + } + + return nil +} + +type OrganizationModel struct { + dbConnectionPool db.DBConnectionPool +} + +func (om *OrganizationModel) Get(ctx context.Context) (*Organization, error) { + var organization Organization + query := ` + SELECT + * + FROM + organizations o + LIMIT 1 + ` + + err := om.dbConnectionPool.GetContext(ctx, &organization, query) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying organization table: %w", err) + } + + return &organization, nil +} + +func (om *OrganizationModel) ArePaymentsEnabled(ctx context.Context) (bool, error) { + var arePaymentsEnabled bool + query := ` + SELECT + o.are_payments_enabled + FROM + organizations o + LIMIT 1 + ` + + err := om.dbConnectionPool.GetContext(ctx, &arePaymentsEnabled, query) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, ErrRecordNotFound + } + return false, fmt.Errorf("error querying organization table: %w", err) + } + + return arePaymentsEnabled, nil +} + +func (om *OrganizationModel) Update(ctx context.Context, ou *OrganizationUpdate) error { + if err := ou.validate(); err != nil { + return fmt.Errorf("invalid organization update: %w", err) + } + + query := ` + WITH org_cte AS ( + SELECT id FROM organizations LIMIT 1 + ) + UPDATE + organizations + SET + %s + FROM org_cte + WHERE organizations.id = org_cte.id + ` + + args := []interface{}{} + fields := []string{} + if ou.Name != "" { + fields = append(fields, "name = ?") + args = append(args, ou.Name) + } + + if len(ou.Logo) > 0 { + fields = append(fields, "logo = ?") + args = append(args, ou.Logo) + } + + if ou.TimezoneUTCOffset != "" { + fields = append(fields, "timezone_utc_offset = ?") + args = append(args, ou.TimezoneUTCOffset) + } + + if ou.IsApprovalRequired != nil { + fields = append(fields, "is_approval_required = ?") + args = append(args, *ou.IsApprovalRequired) + } + + query = om.dbConnectionPool.Rebind(fmt.Sprintf(query, strings.Join(fields, ", "))) + + _, err := om.dbConnectionPool.ExecContext(ctx, query, args...) + if err != nil { + return fmt.Errorf("error updating organization: %w", err) + } + + return nil +} diff --git a/internal/data/organizations_test.go b/internal/data/organizations_test.go new file mode 100644 index 000000000..f9514727c --- /dev/null +++ b/internal/data/organizations_test.go @@ -0,0 +1,353 @@ +package data + +import ( + "bytes" + "context" + "encoding/csv" + "image/gif" + "image/jpeg" + "image/png" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Organizations_DatabaseTriggers(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + t.Run("SQL query will trigger an error if you try to have more than one organization", func(t *testing.T) { + q := ` + INSERT INTO organizations ( + name, stellar_main_address, timezone_utc_offset, are_payments_enabled, sms_registration_message_template + ) + VALUES ( + 'Test name', 'Test Stellar address', '+00:00', false, 'Test template {{.OrganizationName}} {{.RegistrationLink}}.' + ) + ` + _, err := dbConnectionPool.ExecContext(ctx, q) + require.EqualError(t, err, "pq: public.organizations can must contain exactly one row") + }) + + t.Run("SQL query will trigger an error if you try to delete the one organization you must have", func(t *testing.T) { + q := "DELETE FROM organizations" + _, err := dbConnectionPool.ExecContext(ctx, q) + require.EqualError(t, err, "pq: public.organizations can must contain exactly one row") + }) + + t.Run("updating sms_registration_message_template without the tags {{.OrganizationName}} and {{.RegistrationLink}} will trigger an error", func(t *testing.T) { + q := "UPDATE organizations SET sms_registration_message_template = 'Test template'" + _, err := dbConnectionPool.ExecContext(ctx, q) + require.EqualError(t, err, `pq: new row for relation "organizations" violates check constraint "organization_sms_registration_message_template_contains_tags_ch"`) + }) + t.Run("updating sms_registration_message_template with the tags {{.OrganizationName}} and {{.RegistrationLink}} will succeed πŸŽ‰", func(t *testing.T) { + q := "UPDATE organizations SET sms_registration_message_template = 'TAG1: {{.OrganizationName}} and TAG2: {{.RegistrationLink}}.'" + _, err := dbConnectionPool.ExecContext(ctx, q) + require.NoError(t, err) + }) +} + +func Test_Organizations_Get(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + organizationModel := &OrganizationModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns the single organization", func(t *testing.T) { + gotOrganization, err := organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Len(t, gotOrganization.ID, 36) + assert.Equal(t, "MyCustomAid", gotOrganization.Name) + assert.Equal(t, "GDA34JZ26FZY64XCSY46CUNSHLX762LHJXQHWWHGL5HSFRWSGBVHUFNI", gotOrganization.StellarMainAddress) + assert.Equal(t, "+00:00", gotOrganization.TimezoneUTCOffset) + assert.False(t, gotOrganization.ArePaymentsEnabled) + assert.Equal(t, "You have a payment waiting for you from the {{.OrganizationName}}. Click {{.RegistrationLink}} to register.", gotOrganization.SMSRegistrationMessageTemplate) + assert.NotEmpty(t, gotOrganization.CreatedAt) + assert.NotEmpty(t, gotOrganization.UpdatedAt) + assert.False(t, gotOrganization.IsApprovalRequired) + }) +} + +func Test_Organizations_ArePaymentsEnabled(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + organizationModel := &OrganizationModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns false if it's not enabled", func(t *testing.T) { + arePaymentsEnabled, err := organizationModel.ArePaymentsEnabled(ctx) + require.NoError(t, err) + require.False(t, arePaymentsEnabled) + }) + + t.Run("returns true if it's enabled", func(t *testing.T) { + q := "UPDATE organizations SET are_payments_enabled = true" + _, err := dbConnectionPool.ExecContext(ctx, q) + require.NoError(t, err) + + arePaymentsEnabled, err := organizationModel.ArePaymentsEnabled(ctx) + require.NoError(t, err) + require.True(t, arePaymentsEnabled) + }) +} + +func Test_OrganizationUpdate_validate(t *testing.T) { + ou := &OrganizationUpdate{} + err := ou.validate() + assert.EqualError(t, err, "name, timezone UTC offset, approval workflow flag or logo is required") + + ou.Name = "My Org Name" + err = ou.validate() + assert.Nil(t, err) + + // png + img := CreateMockImage(t, 300, 300, ImageSizeSmall) + buf := new(bytes.Buffer) + err = png.Encode(buf, img) + require.NoError(t, err) + + ou.Name = "" + ou.Logo = buf.Bytes() + err = ou.validate() + assert.Nil(t, err) + + // jpeg + img = CreateMockImage(t, 300, 300, ImageSizeSmall) + buf = new(bytes.Buffer) + err = jpeg.Encode(buf, img, &jpeg.Options{Quality: jpeg.DefaultQuality}) + require.NoError(t, err) + + ou.Name = "" + ou.Logo = buf.Bytes() + err = ou.validate() + assert.Nil(t, err) + + ou.Name = "My Org Name" + ou.Logo = buf.Bytes() + err = ou.validate() + assert.Nil(t, err) + + // error decoding image + csvBuf := new(bytes.Buffer) + csvWriter := csv.NewWriter(csvBuf) + err = csvWriter.WriteAll([][]string{ + {"name", "age"}, + {"foo", "99"}, + {"bar", "99"}, + }) + require.NoError(t, err) + + ou.Logo = csvBuf.Bytes() + err = ou.validate() + assert.EqualError(t, err, "error decoding image bytes: image: unknown format") + + // invalid image type + img = CreateMockImage(t, 300, 300, ImageSizeSmall) + buf = new(bytes.Buffer) + err = gif.Encode(buf, img, &gif.Options{}) + require.NoError(t, err) + + ou.Logo = buf.Bytes() + err = ou.validate() + assert.EqualError(t, err, "invalid image type provided. Expect png or jpeg") + + // timezone UTC offset + ou = &OrganizationUpdate{} + + ou.TimezoneUTCOffset = "0" + err = ou.validate() + assert.EqualError(t, err, "invalid timezone UTC offset format. Example: +02:00 or -03:00") + + ou.TimezoneUTCOffset = "+0" + err = ou.validate() + assert.EqualError(t, err, "invalid timezone UTC offset format. Example: +02:00 or -03:00") + + ou.TimezoneUTCOffset = "-5:00" + err = ou.validate() + assert.EqualError(t, err, "invalid timezone UTC offset format. Example: +02:00 or -03:00") + + ou.TimezoneUTCOffset = "-5:0" + err = ou.validate() + assert.EqualError(t, err, "invalid timezone UTC offset format. Example: +02:00 or -03:00") + + ou.TimezoneUTCOffset = "+03:01515515151551515" + err = ou.validate() + assert.EqualError(t, err, "invalid timezone UTC offset format. Example: +02:00 or -03:00") + + ou.TimezoneUTCOffset = "03:00" + err = ou.validate() + assert.EqualError(t, err, "invalid timezone UTC offset format. Example: +02:00 or -03:00") + + ou.TimezoneUTCOffset = "+05:00" + err = ou.validate() + assert.Nil(t, err) + + ou.TimezoneUTCOffset = "-02:00" + err = ou.validate() + assert.Nil(t, err) +} + +func Test_Organizations_Update(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + resetOrganizationInfo := func(t *testing.T, ctx context.Context) { + const q = "UPDATE organizations SET name = 'MyCustomAid', logo = NULL, timezone_utc_offset = '+00:00'" + _, err := dbConnectionPool.ExecContext(ctx, q) + require.NoError(t, err) + } + + organizationModel := &OrganizationModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error with invalid OrganizationUpdate", func(t *testing.T) { + ou := &OrganizationUpdate{} + err := organizationModel.Update(ctx, ou) + assert.EqualError(t, err, "invalid organization update: name, timezone UTC offset, approval workflow flag or logo is required") + }) + + t.Run("updates only organization's name successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + o, err := organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "MyCustomAid", o.Name) + assert.Equal(t, "+00:00", o.TimezoneUTCOffset) + assert.Nil(t, o.Logo) + + ou := &OrganizationUpdate{Name: "My Org Name"} + + err = organizationModel.Update(ctx, ou) + require.NoError(t, err) + + o, err = organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "My Org Name", o.Name) + assert.Equal(t, "+00:00", o.TimezoneUTCOffset) + assert.Nil(t, o.Logo) + }) + + t.Run("updates only organization's timezone UTC offset successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + o, err := organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "+00:00", o.TimezoneUTCOffset) + assert.Equal(t, "MyCustomAid", o.Name) + assert.Nil(t, o.Logo) + + ou := &OrganizationUpdate{TimezoneUTCOffset: "+02:00"} + + err = organizationModel.Update(ctx, ou) + require.NoError(t, err) + + o, err = organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "+02:00", o.TimezoneUTCOffset) + assert.Equal(t, "MyCustomAid", o.Name) + assert.Nil(t, o.Logo) + }) + + t.Run("updates only organization's logo successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + o, err := organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "MyCustomAid", o.Name) + assert.Nil(t, o.Logo) + + img := CreateMockImage(t, 300, 300, ImageSizeSmall) + buf := new(bytes.Buffer) + err = png.Encode(buf, img) + require.NoError(t, err) + + ou := &OrganizationUpdate{Logo: buf.Bytes()} + + err = organizationModel.Update(ctx, ou) + require.NoError(t, err) + + o, err = organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "MyCustomAid", o.Name) + assert.Equal(t, ou.Logo, o.Logo) + }) + + t.Run("updates only organization's is_approval_required successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + o, err := organizationModel.Get(ctx) + require.NoError(t, err) + require.False(t, o.IsApprovalRequired) + + isApprovalRequired := true + ou := &OrganizationUpdate{IsApprovalRequired: &isApprovalRequired} + + err = organizationModel.Update(ctx, ou) + require.NoError(t, err) + + o, err = organizationModel.Get(ctx) + require.NoError(t, err) + require.True(t, o.IsApprovalRequired) + }) + + t.Run("updates organization's name, timezone UTC offset and logo successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + o, err := organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "MyCustomAid", o.Name) + assert.Equal(t, "+00:00", o.TimezoneUTCOffset) + assert.Nil(t, o.Logo) + + img := CreateMockImage(t, 300, 300, ImageSizeSmall) + buf := new(bytes.Buffer) + err = png.Encode(buf, img) + require.NoError(t, err) + + ou := &OrganizationUpdate{Name: "My Org Name", Logo: buf.Bytes(), TimezoneUTCOffset: "+02:00"} + + err = organizationModel.Update(ctx, ou) + require.NoError(t, err) + + o, err = organizationModel.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "My Org Name", o.Name) + assert.Equal(t, "+02:00", o.TimezoneUTCOffset) + assert.Equal(t, ou.Logo, o.Logo) + }) +} diff --git a/internal/data/otp.go b/internal/data/otp.go new file mode 100644 index 000000000..342e29065 --- /dev/null +++ b/internal/data/otp.go @@ -0,0 +1,8 @@ +package data + +const ( + // TestnetAlwaysValidOTP is used for testing purposes and will be considered a valid OTP for any testnet account. + TestnetAlwaysValidOTP = "000000" + // TestnetAlwaysInvalidOTP is used for testing purposes and will be considered an invalid OTP for any testnet account. + TestnetAlwaysInvalidOTP = "999999" +) diff --git a/internal/data/payments.go b/internal/data/payments.go new file mode 100644 index 000000000..97033539c --- /dev/null +++ b/internal/data/payments.go @@ -0,0 +1,581 @@ +package data + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/stellar/go/support/log" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type Payment struct { + ID string `json:"id" db:"id"` + Amount string `json:"amount" db:"amount"` + StellarTransactionID string `json:"stellar_transaction_id" db:"stellar_transaction_id"` + // TODO: evaluate if we will keep or remove StellarOperationID + StellarOperationID string `json:"stellar_operation_id" db:"stellar_operation_id"` + Status PaymentStatus `json:"status" db:"status"` + StatusHistory PaymentStatusHistory `json:"status_history,omitempty" db:"status_history"` + Disbursement *Disbursement `json:"disbursement,omitempty" db:"disbursement"` + Asset Asset `json:"asset"` + ReceiverWallet *ReceiverWallet `json:"receiver_wallet,omitempty" db:"receiver_wallet"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type PaymentStatusHistoryEntry struct { + Status PaymentStatus `json:"status"` + StatusMessage string `json:"status_message"` + Timestamp time.Time `json:"timestamp"` +} + +type PaymentModel struct { + dbConnectionPool db.DBConnectionPool +} + +var ( + DefaultPaymentSortField = SortFieldUpdatedAt + DefaultPaymentSortOrder = SortOrderDESC + AllowedPaymentFilters = []FilterKey{FilterKeyStatus, FilterKeyCreatedAtAfter, FilterKeyCreatedAtBefore, FilterKeyReceiverID} + AllowedPaymentSorts = []SortField{SortFieldCreatedAt, SortFieldUpdatedAt} +) + +type PaymentInsert struct { + ReceiverID string `db:"receiver_id"` + DisbursementID string `db:"disbursement_id"` + Amount string `db:"amount"` + AssetID string `db:"asset_id"` + ReceiverWalletID string `db:"receiver_wallet_id"` +} + +type PaymentUpdate struct { + Status PaymentStatus `db:"status"` + StatusMessage string + StellarTransactionID string `db:"stellar_transaction_id"` +} + +type PaymentStatusHistory []PaymentStatusHistoryEntry + +// Value implements the driver.Valuer interface. +func (psh PaymentStatusHistory) Value() (driver.Value, error) { + var statusHistoryJSON []string + for _, sh := range psh { + shJSONBytes, err := json.Marshal(sh) + if err != nil { + return nil, fmt.Errorf("error converting status history to json for message: %w", err) + } + statusHistoryJSON = append(statusHistoryJSON, string(shJSONBytes)) + } + + return pq.Array(statusHistoryJSON).Value() +} + +// Scan implements the sql.Scanner interface. +func (psh *PaymentStatusHistory) Scan(src interface{}) error { + var statusHistoryJSON []string + if err := pq.Array(&statusHistoryJSON).Scan(src); err != nil { + return fmt.Errorf("error scanning status history value: %w", err) + } + + for _, sh := range statusHistoryJSON { + var shEntry PaymentStatusHistoryEntry + err := json.Unmarshal([]byte(sh), &shEntry) + if err != nil { + return fmt.Errorf("error unmarshaling status_history column: %w", err) + } + *psh = append(*psh, shEntry) + } + + return nil +} + +func (p *PaymentInsert) Validate() error { + if strings.TrimSpace(p.ReceiverID) == "" { + return fmt.Errorf("receiver_id is required") + } + + if strings.TrimSpace(p.DisbursementID) == "" { + return fmt.Errorf("disbursement_id is required") + } + + if err := utils.ValidateAmount(p.Amount); err != nil { + return fmt.Errorf("amount is invalid: %w", err) + } + + if strings.TrimSpace(p.AssetID) == "" { + return fmt.Errorf("asset_id is required") + } + + if strings.TrimSpace(p.ReceiverWalletID) == "" { + return fmt.Errorf("receiver_wallet_id is required") + } + + return nil +} + +func (p *PaymentUpdate) Validate() error { + if err := p.Status.Validate(); err != nil { + return fmt.Errorf("status is invalid: %w", err) + } + if strings.TrimSpace(p.StellarTransactionID) == "" { + return fmt.Errorf("stellar transaction id is required") + } + + return nil +} + +func (p *PaymentModel) Get(ctx context.Context, id string, sqlExec db.SQLExecuter) (*Payment, error) { + payment := Payment{} + + query := ` + SELECT + p.id, + p.amount, + COALESCE(p.stellar_transaction_id, '') as stellar_transaction_id, + COALESCE(p.stellar_operation_id, '') as stellar_operation_id, + p.status, + p.status_history, + p.created_at, + p.updated_at, + d.id as "disbursement.id", + d.name as "disbursement.name", + d.status as "disbursement.status", + d.created_at as "disbursement.created_at", + d.updated_at as "disbursement.updated_at", + a.id as "asset.id", + a.code as "asset.code", + a.issuer as "asset.issuer", + rw.id as "receiver_wallet.id", + COALESCE(rw.stellar_address, '') as "receiver_wallet.stellar_address", + COALESCE(rw.stellar_memo, '') as "receiver_wallet.stellar_memo", + COALESCE(rw.stellar_memo_type, '') as "receiver_wallet.stellar_memo_type", + rw.status as "receiver_wallet.status", + rw.created_at as "receiver_wallet.created_at", + rw.updated_at as "receiver_wallet.updated_at", + rw.receiver_id as "receiver_wallet.receiver.id", + w.id as "receiver_wallet.wallet.id", + w.name as "receiver_wallet.wallet.name" + FROM + payments p + JOIN disbursements d ON p.disbursement_id = d.id + JOIN assets a ON p.asset_id = a.id + JOIN receiver_wallets rw ON rw.receiver_id = p.receiver_id AND rw.wallet_id = d.wallet_id + JOIN wallets w ON rw.wallet_id = w.id + WHERE p.id = $1 + ` + + err := sqlExec.GetContext(ctx, &payment, query, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } else { + return nil, fmt.Errorf("error querying payment ID: %w", err) + } + } + + return &payment, nil +} + +// Count returns the number of payments matching the given query parameters. +func (p *PaymentModel) Count(ctx context.Context, queryParams *QueryParams, sqlExec db.SQLExecuter) (int, error) { + var count int + baseQuery := ` + SELECT + count(*) + FROM + payments p + JOIN disbursements d on p.disbursement_id = d.id + JOIN assets a on p.asset_id = a.id + JOIN wallets w on d.wallet_id = w.id + JOIN receiver_wallets rw on rw.receiver_id = p.receiver_id AND rw.wallet_id = w.id + ` + + query, params := newPaymentQuery(baseQuery, queryParams, false, sqlExec) + + err := sqlExec.GetContext(ctx, &count, query, params...) + if err != nil { + return 0, fmt.Errorf("error counting payments: %w", err) + } + return count, nil +} + +// GetAll returns all PAYMENTS matching the given query parameters. +func (p *PaymentModel) GetAll(ctx context.Context, queryParams *QueryParams, sqlExec db.SQLExecuter) ([]Payment, error) { + payments := []Payment{} + + query := ` + SELECT + p.id, + p.amount, + COALESCE(p.stellar_transaction_id, '') as stellar_transaction_id, + COALESCE(p.stellar_operation_id, '') as stellar_operation_id, + p.status, + p.status_history, + p.created_at, + p.updated_at, + d.id as "disbursement.id", + d.name as "disbursement.name", + d.status as "disbursement.status", + d.created_at as "disbursement.created_at", + d.updated_at as "disbursement.updated_at", + a.id as "asset.id", + a.code as "asset.code", + a.issuer as "asset.issuer", + rw.id as "receiver_wallet.id", + COALESCE(rw.stellar_address, '') as "receiver_wallet.stellar_address", + COALESCE(rw.stellar_memo, '') as "receiver_wallet.stellar_memo", + COALESCE(rw.stellar_memo_type, '') as "receiver_wallet.stellar_memo_type", + rw.status as "receiver_wallet.status", + rw.created_at as "receiver_wallet.created_at", + rw.updated_at as "receiver_wallet.updated_at", + rw.receiver_id as "receiver_wallet.receiver.id", + w.id as "receiver_wallet.wallet.id", + w.name as "receiver_wallet.wallet.name" + FROM + payments p + JOIN disbursements d on p.disbursement_id = d.id + JOIN assets a on p.asset_id = a.id + JOIN wallets w on d.wallet_id = w.id + JOIN receiver_wallets rw on rw.receiver_id = p.receiver_id AND rw.wallet_id = w.id + ` + + query, params := newPaymentQuery(query, queryParams, true, sqlExec) + + err := sqlExec.SelectContext(ctx, &payments, query, params...) + if err != nil { + return nil, fmt.Errorf("error querying payments: %w", err) + } + + return payments, nil +} + +// DeleteAllForDisbursement deletes all payments for a given disbursement. +func (p *PaymentModel) DeleteAllForDisbursement(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) error { + query := ` + DELETE FROM payments + WHERE disbursement_id = $1 + ` + + result, err := sqlExec.ExecContext(ctx, query, disbursementID) + if err != nil { + return fmt.Errorf("error deleting payments for disbursement: %w", err) + } + + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + log.Ctx(ctx).Infof("Deleted %d payments for disbursement %s", numRowsAffected, disbursementID) + + return nil +} + +// InsertAll inserts a batch of payments into the database. +func (p *PaymentModel) InsertAll(ctx context.Context, sqlExec db.SQLExecuter, inserts []PaymentInsert) error { + for _, payment := range inserts { + err := payment.Validate() + if err != nil { + return fmt.Errorf("error validating payment: %w", err) + } + } + query := ` + INSERT INTO payments ( + amount, + asset_id, + receiver_id, + disbursement_id, + receiver_wallet_id + ) VALUES ( + $1, + $2, + $3, + $4, + $5 + ) + ` + + for _, payment := range inserts { + _, err := sqlExec.ExecContext(ctx, query, payment.Amount, payment.AssetID, payment.ReceiverID, payment.DisbursementID, payment.ReceiverWalletID) + if err != nil { + return fmt.Errorf("error inserting payment: %w", err) + } + } + + return nil +} + +// UpdateStatusByDisbursementID updates the status of all payments with a given status for a given disbursement. +func (p *PaymentModel) UpdateStatusByDisbursementID(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string, targetStatus PaymentStatus) error { + sourceStatuses := targetStatus.SourceStatuses() + + query := ` + UPDATE payments + SET status = $1, + status_history = array_append(status_history, create_payment_status_history(NOW(), $1, NULL)) + WHERE disbursement_id = $2 + AND status = ANY($3) + ` + + result, err := sqlExec.ExecContext(ctx, query, targetStatus, disbursementID, pq.Array(sourceStatuses)) + if err != nil { + return fmt.Errorf("error updating payment statuses for disbursement %s: %w", disbursementID, err) + } + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + log.Ctx(ctx).Infof("Set %d payments for disbursement %s from %s to %s", numRowsAffected, disbursementID, sourceStatuses, targetStatus) + + return nil +} + +func (p *PaymentModel) GetBatchForUpdate(ctx context.Context, dbTx db.DBTransaction, batchSize int) ([]*Payment, error) { + if batchSize <= 0 { + return nil, fmt.Errorf("batch size must be greater than 0") + } + + query := ` + SELECT + p.id, + p.amount, + COALESCE(p.stellar_transaction_id, '') as "stellar_transaction_id", + COALESCE(p.stellar_operation_id, '') as "stellar_operation_id", + p.status, + p.created_at, + p.updated_at, + d.id as "disbursement.id", + d.status as "disbursement.status", + a.id as "asset.id", + a.code as "asset.code", + a.issuer as "asset.issuer", + rw.id as "receiver_wallet.id", + rw.receiver_id as "receiver_wallet.receiver.id", + COALESCE(rw.stellar_address, '') as "receiver_wallet.stellar_address", + COALESCE(rw.stellar_memo, '') as "receiver_wallet.stellar_memo", + COALESCE(rw.stellar_memo_type, '') as "receiver_wallet.stellar_memo_type", + rw.status as "receiver_wallet.status" + FROM + payments p + JOIN assets a on p.asset_id = a.id + JOIN receiver_wallets rw on p.receiver_wallet_id = rw.id + JOIN disbursements d on p.disbursement_id = d.id + WHERE p.status = $1 -- 'READY'::payment_status + AND rw.status = $2 -- 'REGISTERED'::receiver_wallet_status + AND d.status = $3 -- 'STARTED'::disbursement_status + ORDER BY p.disbursement_id ASC, p.updated_at ASC + LIMIT $4 + FOR UPDATE SKIP LOCKED + ` + + var payments []*Payment + err := dbTx.SelectContext(ctx, &payments, query, ReadyPaymentStatus, RegisteredReceiversWalletStatus, StartedDisbursementStatus, batchSize) + if err != nil { + return nil, fmt.Errorf("error getting ready payments: %w", err) + } + return payments, nil +} + +func (p *PaymentModel) UpdateStatuses(ctx context.Context, sqlExec db.SQLExecuter, payments []*Payment, toStatus PaymentStatus) error { + if len(payments) == 0 { + log.Ctx(ctx).Info("No payments to update") + return nil + } + // Validate transition + for _, payment := range payments { + if err := payment.Status.TransitionTo(toStatus); err != nil { + return fmt.Errorf("cannot transition from %s to %s for payment %s: %w", payment.Status, toStatus, payment.ID, err) + } + } + var paymentIDs []string + for _, payment := range payments { + paymentIDs = append(paymentIDs, payment.ID) + } + + query := ` + UPDATE payments + SET status = $1, + status_history = array_append(status_history, create_payment_status_history(NOW(), $1, NULL)) + WHERE id = ANY($2) + ` + + result, err := sqlExec.ExecContext(ctx, query, toStatus, pq.Array(paymentIDs)) + if err != nil { + return fmt.Errorf("error updating payment statuses: %w", err) + } + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + log.Ctx(ctx).Infof("Set %d payments to %s", numRowsAffected, toStatus) + + return nil +} + +// Update updates a payment's fields with the given update. +func (p *PaymentModel) Update(ctx context.Context, sqlExec db.SQLExecuter, payment *Payment, update *PaymentUpdate) error { + if err := update.Validate(); err != nil { + return fmt.Errorf("error validating payment update: %w", err) + } + + if err := payment.Status.TransitionTo(update.Status); err != nil { + return fmt.Errorf("cannot transition from %s to %s for payment %s: %w", payment.Status, update.Status, payment.ID, err) + } + + query := ` + UPDATE payments + SET status = $1, + status_history = array_append(status_history, create_payment_status_history(NOW(), $1, $2)), + stellar_transaction_id = COALESCE($3, stellar_transaction_id) + WHERE id = $4 + ` + + result, err := sqlExec.ExecContext(ctx, query, update.Status, update.StatusMessage, update.StellarTransactionID, payment.ID) + if err != nil { + return fmt.Errorf("error updating payment with id %s: %w", payment.ID, err) + } + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected for payment with id %s: %w", payment.ID, err) + } + if numRowsAffected == 0 { + return fmt.Errorf("payment %s status was not updated from %s to %s", payment.ID, payment.Status, update.Status) + } else if numRowsAffected == 1 { + log.Ctx(ctx).Infof("Set payment %s status from %s to %s", payment.ID, payment.Status, update.Status) + } else { + return fmt.Errorf("unexpected number of rows affected: %d when updating payment %s status from %s to %s", numRowsAffected, payment.ID, payment.Status, update.Status) + } + + return nil +} + +func (p *PaymentModel) RetryFailedPayments(ctx context.Context, email string, paymentIDs ...string) error { + return db.RunInTransaction(ctx, p.dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + if len(paymentIDs) == 0 { + return fmt.Errorf("payment ids is required: %w", ErrMissingInput) + } + + if email == "" { + return fmt.Errorf("user email is required: %w", ErrMissingInput) + } + + const query = ` + WITH previous_payments_stellar_transaction_ids AS ( + SELECT + id, + stellar_transaction_id, + $2 AS status_message + FROM + payments + WHERE + id = ANY($1) + AND status = 'FAILED'::payment_status + ) + UPDATE + payments + SET + status = 'READY'::payment_status, + stellar_transaction_id = '', + status_history = array_append(status_history, create_payment_status_history(NOW(), 'READY', CONCAT(pp.status_message, pp.stellar_transaction_id))) + FROM + previous_payments_stellar_transaction_ids pp + WHERE + payments.id = pp.id + ` + + statusMessage := fmt.Sprintf("User %s has requested to retry the payment - Previous Stellar Transaction ID: ", email) + + res, err := dbTx.ExecContext(ctx, query, pq.Array(paymentIDs), statusMessage) + if err != nil { + return fmt.Errorf("error retrying failed payments: %w", err) + } + + numRowsAffected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + if numRowsAffected != int64(len(paymentIDs)) { + return ErrMismatchNumRowsAffected + } + + return nil + }) +} + +// GetByIDs returns a list of payments for the given IDs. +func (p *PaymentModel) GetByIDs(ctx context.Context, sqlExec db.SQLExecuter, paymentIDs []string) ([]*Payment, error) { + payments := []*Payment{} + + if len(paymentIDs) == 0 { + return payments, nil + } + + query := ` + SELECT + p.id, + p.amount, + COALESCE(p.stellar_transaction_id, '') as "stellar_transaction_id", + COALESCE(p.stellar_operation_id, '') as "stellar_operation_id", + p.status, + p.created_at, + p.updated_at, + d.id as "disbursement.id", + d.status as "disbursement.status", + a.id as "asset.id", + a.code as "asset.code", + a.issuer as "asset.issuer", + rw.id as "receiver_wallet.id", + rw.receiver_id as "receiver_wallet.receiver.id", + COALESCE(rw.stellar_address, '') as "receiver_wallet.stellar_address", + COALESCE(rw.stellar_memo, '') as "receiver_wallet.stellar_memo", + COALESCE(rw.stellar_memo_type, '') as "receiver_wallet.stellar_memo_type", + rw.status as "receiver_wallet.status" + FROM + payments p + JOIN assets a on p.asset_id = a.id + JOIN receiver_wallets rw on p.receiver_wallet_id = rw.id + JOIN disbursements d on p.disbursement_id = d.id + WHERE p.id = ANY($1) + ` + + err := sqlExec.SelectContext(ctx, &payments, query, pq.Array(paymentIDs)) + if err != nil { + return nil, fmt.Errorf("error getting payments: %w", err) + } + return payments, nil +} + +// newPaymentQuery generates the full query and parameters for a payment search query +func newPaymentQuery(baseQuery string, queryParams *QueryParams, paginated bool, sqlExec db.SQLExecuter) (string, []interface{}) { + qb := NewQueryBuilder(baseQuery) + if queryParams.Filters[FilterKeyStatus] != nil { + qb.AddCondition("p.status = ?", queryParams.Filters[FilterKeyStatus]) + } + if queryParams.Filters[FilterKeyReceiverID] != nil { + qb.AddCondition("p.receiver_id = ?", queryParams.Filters[FilterKeyReceiverID]) + } + if queryParams.Filters[FilterKeyCreatedAtAfter] != nil { + qb.AddCondition("p.created_at >= ?", queryParams.Filters[FilterKeyCreatedAtAfter]) + } + if queryParams.Filters[FilterKeyCreatedAtBefore] != nil { + qb.AddCondition("p.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore]) + } + if paginated { + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "p") + qb.AddPagination(queryParams.Page, queryParams.PageLimit) + } + query, params := qb.Build() + return sqlExec.Rebind(query), params +} diff --git a/internal/data/payments_state_machine.go b/internal/data/payments_state_machine.go new file mode 100644 index 000000000..356bcb19f --- /dev/null +++ b/internal/data/payments_state_machine.go @@ -0,0 +1,69 @@ +package data + +import ( + "fmt" + "strings" +) + +type PaymentStatus string + +const ( + DraftPaymentStatus PaymentStatus = "DRAFT" + ReadyPaymentStatus PaymentStatus = "READY" + PendingPaymentStatus PaymentStatus = "PENDING" + PausedPaymentStatus PaymentStatus = "PAUSED" + SuccessPaymentStatus PaymentStatus = "SUCCESS" + FailedPaymentStatus PaymentStatus = "FAILED" +) + +// Validate validates the payment status +func (status PaymentStatus) Validate() error { + switch PaymentStatus(strings.ToUpper(string(status))) { + case DraftPaymentStatus, ReadyPaymentStatus, PendingPaymentStatus, + PausedPaymentStatus, SuccessPaymentStatus, FailedPaymentStatus: + return nil + default: + return fmt.Errorf("invalid payment status: %s", status) + } +} + +// TransitionTo transitions the payment status to the target state +func (status PaymentStatus) TransitionTo(targetState PaymentStatus) error { + return PaymentStateMachineWithInitialState(status).TransitionTo(targetState.State()) +} + +// PaymentStateMachineWithInitialState returns a state machine for Payments initialized with the given state +func PaymentStateMachineWithInitialState(initialState PaymentStatus) *StateMachine { + transitions := []StateTransition{ + {From: DraftPaymentStatus.State(), To: ReadyPaymentStatus.State()}, // disbursement started + {From: ReadyPaymentStatus.State(), To: PendingPaymentStatus.State()}, // payment gets submitted if user is ready + {From: ReadyPaymentStatus.State(), To: PausedPaymentStatus.State()}, // payment paused (when disbursement paused) + {From: PausedPaymentStatus.State(), To: ReadyPaymentStatus.State()}, // payment resumed (when disbursement resumed) + {From: PendingPaymentStatus.State(), To: FailedPaymentStatus.State()}, // payment fails + {From: FailedPaymentStatus.State(), To: PendingPaymentStatus.State()}, // payment retried + {From: PendingPaymentStatus.State(), To: SuccessPaymentStatus.State()}, // payment succeeds + } + + return NewStateMachine(initialState.State(), transitions) +} + +// PaymentStatuses returns a list of all possible payment statuses +func PaymentStatuses() []PaymentStatus { + return []PaymentStatus{DraftPaymentStatus, ReadyPaymentStatus, PendingPaymentStatus, PausedPaymentStatus, SuccessPaymentStatus, FailedPaymentStatus} +} + +// SourceStatuses returns a list of states that the payment status can transition from given the target state +func (status PaymentStatus) SourceStatuses() []PaymentStatus { + stateMachine := PaymentStateMachineWithInitialState(DraftPaymentStatus) + fromStates := []PaymentStatus{} + for _, fromState := range PaymentStatuses() { + if stateMachine.Transitions[fromState.State()][status.State()] { + fromStates = append(fromStates, fromState) + } + } + return fromStates +} + +func (status PaymentStatus) State() State { + return State(status) +} diff --git a/internal/data/payments_state_machine_test.go b/internal/data/payments_state_machine_test.go new file mode 100644 index 000000000..5e0e05067 --- /dev/null +++ b/internal/data/payments_state_machine_test.go @@ -0,0 +1,56 @@ +package data + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_PaymentStatus_SourceStatuses(t *testing.T) { + tests := []struct { + name string + targetStatus PaymentStatus + expectedSourceStatuses []PaymentStatus + }{ + { + name: "Draft", + targetStatus: DraftPaymentStatus, + expectedSourceStatuses: []PaymentStatus{}, + }, + { + name: "Ready", + targetStatus: ReadyPaymentStatus, + expectedSourceStatuses: []PaymentStatus{DraftPaymentStatus, PausedPaymentStatus}, + }, + { + name: "Pending", + targetStatus: PendingPaymentStatus, + expectedSourceStatuses: []PaymentStatus{ReadyPaymentStatus, FailedPaymentStatus}, + }, + { + name: "Paused", + targetStatus: PausedPaymentStatus, + expectedSourceStatuses: []PaymentStatus{ReadyPaymentStatus}, + }, + { + name: "Success", + targetStatus: SuccessPaymentStatus, + expectedSourceStatuses: []PaymentStatus{PendingPaymentStatus}, + }, + { + name: "Failure", + targetStatus: FailedPaymentStatus, + expectedSourceStatuses: []PaymentStatus{PendingPaymentStatus}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expectedSourceStatuses, tt.targetStatus.SourceStatuses()) + }) + } +} + +func Test_PaymentStatus_PaymentStatuses(t *testing.T) { + expectedStatuses := []PaymentStatus{DraftPaymentStatus, ReadyPaymentStatus, PendingPaymentStatus, PausedPaymentStatus, SuccessPaymentStatus, FailedPaymentStatus} + require.Equal(t, expectedStatuses, PaymentStatuses()) +} diff --git a/internal/data/payments_test.go b/internal/data/payments_test.go new file mode 100644 index 000000000..1f85fdc21 --- /dev/null +++ b/internal/data/payments_test.go @@ -0,0 +1,873 @@ +package data + +import ( + "context" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PaymentsModelGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, DraftReceiversWalletStatus) + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &Disbursement{ + Name: "disbursement 1", + Status: DraftDisbursementStatus, + Asset: asset, + Wallet: wallet1, + Country: country, + CreatedAt: time.Date(2022, 3, 21, 23, 40, 20, 1431, time.UTC), + }) + + paymentModel := PaymentModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when payment does not exist", func(t *testing.T) { + _, err := paymentModel.Get(ctx, "invalid_id", dbConnectionPool) + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns payment successfully", func(t *testing.T) { + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + expected := CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: DraftPaymentStatus, + StatusHistory: []PaymentStatusHistoryEntry{ + { + Status: DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet1, + }) + actual, err := paymentModel.Get(ctx, expected.ID, dbConnectionPool) + require.NoError(t, err) + + assert.Equal(t, *expected, *actual) + }) + + t.Run("returns payment successfully receiver with multiple wallets", func(t *testing.T) { + wallet2 := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet2", "https://www.wallet2.com", "www.wallet2.com", "wallet2://") + + receiverWallet2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, DraftReceiversWalletStatus) + + disbursement2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &Disbursement{ + Name: "disbursement 2", + Status: DraftDisbursementStatus, + Asset: asset, + Wallet: wallet2, + Country: country, + }) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + expected := CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: DraftPaymentStatus, + StatusHistory: []PaymentStatusHistoryEntry{ + { + Status: DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement2, + Asset: *asset, + ReceiverWallet: receiverWallet2, + }) + actual, err := paymentModel.Get(ctx, expected.ID, dbConnectionPool) + require.NoError(t, err) + + assert.Equal(t, *expected, *actual) + }) +} + +func Test_PaymentModelCount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &Disbursement{ + Name: "disbursement 1", + Status: DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + disbursement2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &Disbursement{ + Name: "disbursement 2", + Status: DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + paymentModel := PaymentModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns 0 when no payments exist", func(t *testing.T) { + count, errPayment := paymentModel.Count(ctx, &QueryParams{}, dbConnectionPool) + require.NoError(t, errPayment) + assert.Equal(t, 0, count) + }) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: DraftPaymentStatus, + StatusHistory: []PaymentStatusHistoryEntry{ + { + Status: DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + Amount: "150", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: PendingPaymentStatus, + StatusHistory: []PaymentStatusHistoryEntry{ + { + Status: PendingPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement2, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + t.Run("returns count of payments", func(t *testing.T) { + count, err := paymentModel.Count(ctx, &QueryParams{}, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, 2, count) + }) + + t.Run("returns count of payments with filter", func(t *testing.T) { + filters := map[FilterKey]interface{}{ + FilterKeyStatus: DraftPaymentStatus, + } + count, err := paymentModel.Count(ctx, &QueryParams{Filters: filters}, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, 1, count) + }) +} + +func Test_PaymentModelGetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &Disbursement{ + Name: "disbursement 1", + Status: DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + disbursement2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &Disbursement{ + Name: "disbursement 2", + Status: DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + paymentModel := PaymentModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns empty list when no payments exist", func(t *testing.T) { + payments, errPayment := paymentModel.GetAll(ctx, &QueryParams{}, dbConnectionPool) + require.NoError(t, errPayment) + assert.Equal(t, 0, len(payments)) + }) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + expectedPayment1 := CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: DraftPaymentStatus, + StatusHistory: []PaymentStatusHistoryEntry{ + { + Status: DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + expectedPayment2 := CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ + Amount: "150", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: PendingPaymentStatus, + StatusHistory: []PaymentStatusHistoryEntry{ + { + Status: DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement2, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + t.Run("returns payments successfully", func(t *testing.T) { + actualPayments, err := paymentModel.GetAll(ctx, &QueryParams{}, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, 2, len(actualPayments)) + assert.Equal(t, []Payment{*expectedPayment2, *expectedPayment1}, actualPayments) + }) + + t.Run("returns payments successfully with limit", func(t *testing.T) { + actualPayments, err := paymentModel.GetAll(ctx, &QueryParams{Page: 1, PageLimit: 1}, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, 1, len(actualPayments)) + assert.Equal(t, []Payment{*expectedPayment1}, actualPayments) + }) + + t.Run("returns payments successfully with offset", func(t *testing.T) { + actualPayments, err := paymentModel.GetAll(ctx, &QueryParams{Page: 2, PageLimit: 1}, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, 1, len(actualPayments)) + assert.Equal(t, []Payment{*expectedPayment2}, actualPayments) + }) + + t.Run("returns payments successfully with created at order", func(t *testing.T) { + actualPayments, err := paymentModel.GetAll(ctx, &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}, dbConnectionPool) + + require.NoError(t, err) + assert.Equal(t, 2, len(actualPayments)) + assert.Equal(t, []Payment{*expectedPayment1, *expectedPayment2}, actualPayments) + }) + + t.Run("returns payments successfully with updated at order", func(t *testing.T) { + actualPayments, err := paymentModel.GetAll(ctx, &QueryParams{SortBy: SortFieldUpdatedAt, SortOrder: SortOrderASC}, dbConnectionPool) + + require.NoError(t, err) + assert.Equal(t, 2, len(actualPayments)) + assert.Equal(t, []Payment{*expectedPayment1, *expectedPayment2}, actualPayments) + }) + + t.Run("returns payments successfully with filter", func(t *testing.T) { + filters := map[FilterKey]interface{}{ + FilterKeyStatus: PendingPaymentStatus, + } + actualPayments, err := paymentModel.GetAll(ctx, &QueryParams{Filters: filters}, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, 1, len(actualPayments)) + assert.Equal(t, []Payment{*expectedPayment2}, actualPayments) + }) + + t.Run("should not return duplicated entries when receiver are in more than one disbursements with different wallets", func(t *testing.T) { + models, err := NewModels(dbConnectionPool) + require.NoError(t, err) + + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + + usdc := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil") + demoWallet := CreateWalletFixture(t, ctx, dbConnectionPool, "Demo Wallet", "https://demo-wallet.stellar.org", "https://demo-wallet.stellar.org", "demo-wallet-server.stellar.org") + vibrantWallet := CreateWalletFixture(t, ctx, dbConnectionPool, "Vibrant Assist", "https://vibrantapp.com", "api-dev.vibrantapp.com", "https://vibrantapp.com/sdp-dev") + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverDemoWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, demoWallet.ID, ReadyReceiversWalletStatus) + receiverVibrantWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, vibrantWallet.ID, ReadyReceiversWalletStatus) + + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Name: "disbursement 1", + Status: ReadyDisbursementStatus, + Asset: usdc, + Wallet: demoWallet, + Country: country, + }) + + disbursement2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Name: "disbursement 2", + Status: ReadyDisbursementStatus, + Asset: usdc, + Wallet: vibrantWallet, + Country: country, + }) + + demoWalletPayment := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "100", + Status: ReadyPaymentStatus, + Disbursement: disbursement1, + Asset: *usdc, + ReceiverWallet: receiverDemoWallet, + }) + + vibrantWalletPayment := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "100", + Status: ReadyPaymentStatus, + Disbursement: disbursement2, + Asset: *usdc, + ReceiverWallet: receiverVibrantWallet, + }) + + payments, err := models.Payment.GetAll(ctx, &QueryParams{ + Filters: map[FilterKey]interface{}{ + FilterKeyReceiverID: receiver.ID, + }, + }, dbConnectionPool) + require.NoError(t, err) + + assert.Len(t, payments, 2) + assert.Equal(t, []Payment{ + *demoWalletPayment, + *vibrantWalletPayment, + }, payments) + }) +} + +// func Test_PaymentsModelGetByIDs(t *testing.T) { +// dbt := dbtest.Open(t) +// defer dbt.Close() + +// dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) +// require.NoError(t, err) +// defer dbConnectionPool.Close() + +// ctx := context.Background() + +// asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") +// country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") +// wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + +// receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) +// receiverWallet1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, DraftReceiversWalletStatus) + +// receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) +// receiverWallet2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet1.ID, DraftReceiversWalletStatus) + +// disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} +// disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &Disbursement{ +// Name: "disbursement 1", +// Status: Draft, +// Asset: asset, +// Wallet: wallet1, +// Country: country, +// CreatedAt: time.Date(2022, 3, 21, 23, 40, 20, 1431, time.UTC), +// }) + +// paymentModel := PaymentModel{dbConnectionPool: dbConnectionPool} + +// t.Run("returns empty list when payments ids are not found", func(t *testing.T) { +// payments, err := paymentModel.GetByIDs(ctx, dbConnectionPool, []string{"invalid_id"}) +// require.NoError(t, err) +// require.Empty(t, payments) +// }) + +// t.Run("returns payments successfully", func(t *testing.T) { +// stellarTransactionID, err := utils.RandomString(64) +// require.NoError(t, err) +// stellarOperationID, err := utils.RandomString(32) +// require.NoError(t, err) + +// payment1 := CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ +// Amount: "50", +// StellarTransactionID: stellarTransactionID, +// StellarOperationID: stellarOperationID, +// Status: DraftPaymentStatus, +// StatusHistory: []PaymentStatusHistoryEntry{ +// { +// Status: DraftPaymentStatus, +// StatusMessage: "", +// Timestamp: time.Now(), +// }, +// }, +// Disbursement: disbursement1, +// Asset: *asset, +// ReceiverWallet: receiverWallet1, +// }) + +// stellarTransactionID, err = utils.RandomString(64) +// require.NoError(t, err) +// stellarOperationID, err = utils.RandomString(32) +// require.NoError(t, err) + +// payment2 := CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &Payment{ +// Amount: "150", +// StellarTransactionID: stellarTransactionID, +// StellarOperationID: stellarOperationID, +// Status: DraftPaymentStatus, +// StatusHistory: []PaymentStatusHistoryEntry{ +// { +// Status: DraftPaymentStatus, +// StatusMessage: "", +// Timestamp: time.Now(), +// }, +// }, +// Disbursement: disbursement1, +// Asset: *asset, +// ReceiverWallet: receiverWallet2, +// }) +// actual, err := paymentModel.GetByIDs(ctx, dbConnectionPool, []string{payment1.ID, payment2.ID}) +// require.NoError(t, err) + +// p1 := Payment{ +// ID: payment1.ID, +// Amount: payment1.Amount, +// StellarTransactionID: payment1.StellarTransactionID, +// StellarOperationID: payment1.StellarOperationID, +// Status: payment1.Status, +// CreatedAt: payment1.CreatedAt, +// UpdatedAt: payment1.UpdatedAt, +// Disbursement: &Disbursement{ +// ID: payment1.Disbursement.ID, +// Status: payment1.Disbursement.Status, +// }, +// Asset: Asset{ +// ID: payment1.Asset.ID, +// Code: payment1.Asset.Code, +// Issuer: payment1.Asset.Issuer, +// }, +// ReceiverWallet: &ReceiverWallet{ +// ID: payment1.ReceiverWallet.ID, +// StellarAddress: payment1.ReceiverWallet.StellarAddress, +// StellarMemo: payment1.ReceiverWallet.StellarMemo, +// StellarMemoType: payment1.ReceiverWallet.StellarMemoType, +// Status: payment1.ReceiverWallet.Status, +// Receiver: Receiver{ +// ID: payment1.ReceiverWallet.Receiver.ID, +// }, +// }, +// } + +// p2 := Payment{ +// ID: payment2.ID, +// Amount: payment2.Amount, +// StellarTransactionID: payment2.StellarTransactionID, +// StellarOperationID: payment2.StellarOperationID, +// Status: payment2.Status, +// CreatedAt: payment2.CreatedAt, +// UpdatedAt: payment2.UpdatedAt, +// Disbursement: &Disbursement{ +// ID: payment2.Disbursement.ID, +// Status: payment2.Disbursement.Status, +// }, +// Asset: Asset{ +// ID: payment2.Asset.ID, +// Code: payment2.Asset.Code, +// Issuer: payment2.Asset.Issuer, +// }, +// ReceiverWallet: &ReceiverWallet{ +// ID: payment2.ReceiverWallet.ID, +// StellarAddress: payment2.ReceiverWallet.StellarAddress, +// StellarMemo: payment2.ReceiverWallet.StellarMemo, +// StellarMemoType: payment2.ReceiverWallet.StellarMemoType, +// Status: payment2.ReceiverWallet.Status, +// Receiver: Receiver{ +// ID: payment2.ReceiverWallet.Receiver.ID, +// }, +// }, +// } + +// payments := []*Payment{&p1, &p2} +// assert.Equal(t, payments, actual) +// }) +// } + +func Test_PaymentNewPaymentQuery(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + testCases := []struct { + name string + baseQuery string + queryParams QueryParams + paginated bool + expectedQuery string + expectedParams []interface{} + }{ + { + name: "build payment query without params and pagination", + baseQuery: "SELECT * FROM payments p", + queryParams: QueryParams{}, + paginated: false, + expectedQuery: "SELECT * FROM payments p", + expectedParams: []interface{}{}, + }, + { + name: "build payment query with status filter", + baseQuery: "SELECT * FROM payments p", + queryParams: QueryParams{ + Filters: map[FilterKey]interface{}{ + FilterKeyStatus: "draft", + }, + }, + paginated: false, + expectedQuery: "SELECT * FROM payments p WHERE 1=1 AND p.status = $1", + expectedParams: []interface{}{"draft"}, + }, + { + name: "build payment query with receiver_id filter", + baseQuery: "SELECT * FROM payments p", + queryParams: QueryParams{ + Filters: map[FilterKey]interface{}{ + FilterKeyReceiverID: "receiver_id", + }, + }, + paginated: false, + expectedQuery: "SELECT * FROM payments p WHERE 1=1 AND p.receiver_id = $1", + expectedParams: []interface{}{"receiver_id"}, + }, + { + name: "build payment query with created_at filters", + baseQuery: "SELECT * FROM payments p", + queryParams: QueryParams{ + Filters: map[FilterKey]interface{}{ + FilterKeyCreatedAtAfter: "00-01-01", + FilterKeyCreatedAtBefore: "00-01-31", + }, + }, + paginated: false, + expectedQuery: "SELECT * FROM payments p WHERE 1=1 AND p.created_at >= $1 AND p.created_at <= $2", + expectedParams: []interface{}{"00-01-01", "00-01-31"}, + }, + { + name: "build payment query with pagination", + baseQuery: "SELECT * FROM payments p", + queryParams: QueryParams{ + Page: 1, + PageLimit: 20, + SortBy: "created_at", + SortOrder: "ASC", + }, + paginated: true, + expectedQuery: "SELECT * FROM payments p ORDER BY p.created_at ASC LIMIT $1 OFFSET $2", + expectedParams: []interface{}{20, 0}, + }, + { + name: "build payment query with all filters and pagination", + baseQuery: "SELECT * FROM payments p", + queryParams: QueryParams{ + Page: 1, + PageLimit: 20, + SortBy: "created_at", + SortOrder: "ASC", + Filters: map[FilterKey]interface{}{ + FilterKeyStatus: "draft", + FilterKeyReceiverID: "receiver_id", + FilterKeyCreatedAtAfter: "00-01-01", + FilterKeyCreatedAtBefore: "00-01-31", + }, + }, + paginated: true, + expectedQuery: "SELECT * FROM payments p WHERE 1=1 AND p.status = $1 AND p.receiver_id = $2 AND p.created_at >= $3 AND p.created_at <= $4 ORDER BY p.created_at ASC LIMIT $5 OFFSET $6", + expectedParams: []interface{}{"draft", "receiver_id", "00-01-01", "00-01-31", 20, 0}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query, params := newPaymentQuery(tc.baseQuery, &tc.queryParams, tc.paginated, dbConnectionPool) + + assert.Equal(t, tc.expectedQuery, query) + assert.Equal(t, tc.expectedParams, params) + }) + } +} + +func Test_PaymentModelRetryFailedPayments(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + models, err := NewModels(dbConnectionPool) + require.NoError(t, err) + + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + + country := CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://") + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, ReadyReceiversWalletStatus) + + disbursement := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Country: country, + Wallet: wallet, + Asset: asset, + Status: ReadyDisbursementStatus, + VerificationField: VerificationFieldDateOfBirth, + }) + + t.Run("does not update payments when no payments IDs is given", func(t *testing.T) { + err := models.Payment.RetryFailedPayments(ctx, "user@test.com") + assert.ErrorIs(t, err, ErrMissingInput) + }) + + t.Run("does not update payments when email is empty", func(t *testing.T) { + err := models.Payment.RetryFailedPayments(ctx, "", "payment-id") + assert.ErrorIs(t, err, ErrMissingInput) + }) + + t.Run("returns error when no rows is affected", func(t *testing.T) { + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + + payment1 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: PendingPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment2 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: ReadyPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + err := models.Payment.RetryFailedPayments(ctx, "user@test.com", payment1.ID, payment2.ID) + assert.ErrorIs(t, err, ErrMismatchNumRowsAffected) + + payment1DB, err := models.Payment.Get(ctx, payment1.ID, dbConnectionPool) + require.NoError(t, err) + + payment2DB, err := models.Payment.Get(ctx, payment2.ID, dbConnectionPool) + require.NoError(t, err) + + // Payment 1 + assert.Equal(t, PendingPaymentStatus, payment1DB.Status) + assert.Equal(t, payment1.StellarTransactionID, payment1DB.StellarTransactionID) + assert.Equal(t, payment1.StatusHistory, payment1DB.StatusHistory) + + // Payment 2 + assert.Equal(t, ReadyPaymentStatus, payment2DB.Status) + assert.Equal(t, payment2.StellarTransactionID, payment2DB.StellarTransactionID) + assert.Equal(t, payment2.StatusHistory, payment2DB.StatusHistory) + }) + + t.Run("returns error when the number of affected rows is different from the length of payment IDs", func(t *testing.T) { + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + + payment1 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: PendingPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment2 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: ReadyPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment3 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-3", + StellarOperationID: "operation-id-3", + Status: FailedPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + err := models.Payment.RetryFailedPayments(ctx, "user@test.com", payment1.ID, payment2.ID, payment3.ID) + assert.ErrorIs(t, err, ErrMismatchNumRowsAffected) + + payment1DB, err := models.Payment.Get(ctx, payment1.ID, dbConnectionPool) + require.NoError(t, err) + + payment2DB, err := models.Payment.Get(ctx, payment2.ID, dbConnectionPool) + require.NoError(t, err) + + payment3DB, err := models.Payment.Get(ctx, payment3.ID, dbConnectionPool) + require.NoError(t, err) + + // Payment 1 + assert.Equal(t, PendingPaymentStatus, payment1DB.Status) + assert.Equal(t, payment1.StellarTransactionID, payment1DB.StellarTransactionID) + assert.Equal(t, payment1.StatusHistory, payment1DB.StatusHistory) + + // Payment 2 + assert.Equal(t, ReadyPaymentStatus, payment2DB.Status) + assert.Equal(t, payment2.StellarTransactionID, payment2DB.StellarTransactionID) + assert.Equal(t, payment2.StatusHistory, payment2DB.StatusHistory) + + // Payment 3 + assert.Equal(t, FailedPaymentStatus, payment3DB.Status) + assert.Equal(t, payment3.StellarTransactionID, payment3DB.StellarTransactionID) + assert.Equal(t, payment3.StatusHistory, payment3DB.StatusHistory) + }) + + t.Run("successfully updates failed payments", func(t *testing.T) { + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + + payment1 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: FailedPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment2 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: FailedPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + err := models.Payment.RetryFailedPayments(ctx, "user@test.com", payment1.ID, payment2.ID) + require.NoError(t, err) + + payment1DB, err := models.Payment.Get(ctx, payment1.ID, dbConnectionPool) + require.NoError(t, err) + + payment2DB, err := models.Payment.Get(ctx, payment2.ID, dbConnectionPool) + require.NoError(t, err) + + // Payment 1 + assert.Equal(t, ReadyPaymentStatus, payment1DB.Status) + assert.Empty(t, payment1DB.StellarTransactionID) + assert.NotEqual(t, payment1.StatusHistory, payment1DB.StatusHistory) + assert.Len(t, payment1DB.StatusHistory, 2) + assert.Equal(t, ReadyPaymentStatus, payment1DB.StatusHistory[1].Status) + assert.Equal(t, "User user@test.com has requested to retry the payment - Previous Stellar Transaction ID: stellar-transaction-id-1", payment1DB.StatusHistory[1].StatusMessage) + + // Payment 2 + assert.Equal(t, ReadyPaymentStatus, payment2DB.Status) + assert.Empty(t, payment2DB.StellarTransactionID) + assert.NotEqual(t, payment2.StatusHistory, payment2DB.StatusHistory) + assert.Len(t, payment2DB.StatusHistory, 2) + assert.Equal(t, ReadyPaymentStatus, payment2DB.StatusHistory[1].Status) + assert.Equal(t, "User user@test.com has requested to retry the payment - Previous Stellar Transaction ID: stellar-transaction-id-2", payment2DB.StatusHistory[1].StatusMessage) + }) +} diff --git a/internal/data/query_builder.go b/internal/data/query_builder.go new file mode 100644 index 000000000..7227b208f --- /dev/null +++ b/internal/data/query_builder.go @@ -0,0 +1,70 @@ +package data + +import ( + "fmt" +) + +// QueryBuilder is a helper struct for building SQL queries +type QueryBuilder struct { + baseQuery string + whereClause string + whereParams []interface{} + sortClause string + paginationClause string + paginationParams []interface{} +} + +// NewQueryBuilder creates a new QueryBuilder +func NewQueryBuilder(query string) *QueryBuilder { + return &QueryBuilder{ + baseQuery: query, + } +} + +// AddCondition adds a condition to the query +// If the value is nil or empty, the condition is not added +// The condition should be a string with a placeholder for the value e.g. "name = ?", "id > ?" +func (qb *QueryBuilder) AddCondition(condition string, value ...interface{}) *QueryBuilder { + if len(value) > 0 { + qb.whereClause = fmt.Sprintf("%s %s", qb.whereClause, "AND "+condition) + qb.whereParams = append(qb.whereParams, value...) + } + return qb +} + +// AddSorting adds a sorting clause to the query +// prefix is the prefix to use for the sort field e.g. "d" for "d.created_at" +func (qb *QueryBuilder) AddSorting(sortField SortField, sortOrder SortOrder, prefix string) *QueryBuilder { + if sortField != "" { + qb.sortClause = fmt.Sprintf("ORDER BY %s.%s %s", prefix, sortField, sortOrder) + } + return qb +} + +// AddPagination adds a pagination clause to the query +func (qb *QueryBuilder) AddPagination(page int, pageLimit int) *QueryBuilder { + if page > 0 && pageLimit > 0 { + offset := (page - 1) * pageLimit + qb.paginationClause = "LIMIT ? OFFSET ?" + qb.paginationParams = append(qb.paginationParams, pageLimit, offset) + } + return qb +} + +// Build assembles all statements in the correct order and returns the query and the parameters +func (qb *QueryBuilder) Build() (string, []interface{}) { + query := qb.baseQuery + params := []interface{}{} + if qb.whereClause != "" { + query = fmt.Sprintf("%s WHERE 1=1%s", query, qb.whereClause) + params = append(params, qb.whereParams...) + } + if qb.sortClause != "" { + query = fmt.Sprintf("%s %s", query, qb.sortClause) + } + if qb.paginationClause != "" { + query = fmt.Sprintf("%s %s", query, qb.paginationClause) + params = append(params, qb.paginationParams...) + } + return query, params +} diff --git a/internal/data/query_builder_test.go b/internal/data/query_builder_test.go new file mode 100644 index 000000000..346c375ca --- /dev/null +++ b/internal/data/query_builder_test.go @@ -0,0 +1,66 @@ +package data + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_QueryBuilder(t *testing.T) { + t.Run("Test AddCondition", func(t *testing.T) { + qb := NewQueryBuilder("SELECT * FROM disbursements") + + qb.AddCondition("name = ?", "Disbursement 1") + actual, params := qb.Build() + + expectedQuery := "SELECT * FROM disbursements WHERE 1=1 AND name = ?" + + assert.Equal(t, expectedQuery, actual) + assert.Equal(t, []interface{}{"Disbursement 1"}, params) + }) + + t.Run("Test AddCondition multiple params", func(t *testing.T) { + qb := NewQueryBuilder("SELECT * FROM receivers") + + qb.AddCondition("(id ILIKE ? OR email ILIKE ? OR phone_number ILIKE ?)", "id", "mock@email.com", "+9999999") + actual, params := qb.Build() + + expectedQuery := "SELECT * FROM receivers WHERE 1=1 AND (id ILIKE ? OR email ILIKE ? OR phone_number ILIKE ?)" + + assert.Equal(t, expectedQuery, actual) + assert.Equal(t, []interface{}{"id", "mock@email.com", "+9999999"}, params) + }) + + t.Run("Test AddSorting", func(t *testing.T) { + qb := NewQueryBuilder("SELECT * FROM disbursements d") + + qb.AddSorting("created_at", "DESC", "d") + actual, _ := qb.Build() + + expectedQuery := "SELECT * FROM disbursements d ORDER BY d.created_at DESC" + assert.Equal(t, expectedQuery, actual) + }) + + t.Run("Test AddPagination", func(t *testing.T) { + qb := NewQueryBuilder("SELECT * FROM disbursements d") + + qb.AddPagination(2, 20) + actual, params := qb.Build() + + expectedQuery := "SELECT * FROM disbursements d LIMIT ? OFFSET ?" + assert.Equal(t, expectedQuery, actual) + assert.Equal(t, []interface{}{20, 20}, params) + }) + + t.Run("Test Full query", func(t *testing.T) { + qb := NewQueryBuilder("SELECT * FROM disbursements d") + qb.AddCondition("name = ?", "Disbursement 1") + qb.AddSorting("created_at", "DESC", "d") + qb.AddPagination(2, 20) + actual, params := qb.Build() + + expectedQuery := "SELECT * FROM disbursements d WHERE 1=1 AND name = ? ORDER BY d.created_at DESC LIMIT ? OFFSET ?" + assert.Equal(t, expectedQuery, actual) + assert.Equal(t, []interface{}{"Disbursement 1", 20, 20}, params) + }) +} diff --git a/internal/data/query_params.go b/internal/data/query_params.go new file mode 100644 index 000000000..c24cd4811 --- /dev/null +++ b/internal/data/query_params.go @@ -0,0 +1,34 @@ +package data + +type QueryParams struct { + Query string + Page int + PageLimit int + SortBy SortField + SortOrder SortOrder + Filters map[FilterKey]interface{} +} + +type SortOrder string + +const ( + SortOrderASC SortOrder = "ASC" + SortOrderDESC SortOrder = "DESC" +) + +type SortField string + +const ( + SortFieldName SortField = "name" + SortFieldCreatedAt SortField = "created_at" + SortFieldUpdatedAt SortField = "updated_at" +) + +type FilterKey string + +const ( + FilterKeyStatus FilterKey = "status" + FilterKeyReceiverID FilterKey = "receiver_id" + FilterKeyCreatedAtAfter FilterKey = "created_at_after" + FilterKeyCreatedAtBefore FilterKey = "created_at_before" +) diff --git a/internal/data/receiver_verification.go b/internal/data/receiver_verification.go new file mode 100644 index 000000000..06fbb3f1e --- /dev/null +++ b/internal/data/receiver_verification.go @@ -0,0 +1,181 @@ +package data + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/stellar/go/support/log" + + "golang.org/x/crypto/bcrypt" + + "github.com/lib/pq" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type ReceiverVerification struct { + ReceiverID string `db:"receiver_id"` + VerificationField VerificationField `db:"verification_field"` + HashedValue string `db:"hashed_value"` + Attempts int `db:"attempts"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + ConfirmedAt *time.Time `db:"confirmed_at"` + FailedAt *time.Time `db:"failed_at"` +} + +type ReceiverVerificationModel struct{} + +type ReceiverVerificationInsert struct { + ReceiverID string `db:"receiver_id"` + VerificationField VerificationField `db:"verification_field"` + VerificationValue string `db:"hashed_value"` +} + +const MaxAttemptsAllowed = 6 + +func (rvi *ReceiverVerificationInsert) Validate() error { + if strings.TrimSpace(rvi.ReceiverID) == "" { + return fmt.Errorf("receiver id is required") + } + if rvi.VerificationField == "" { + return fmt.Errorf("verification field is required") + } + if strings.TrimSpace(rvi.VerificationValue) == "" { + return fmt.Errorf("verification value is required") + } + return nil +} + +// GetByReceiverIdsAndVerificationField returns receiver verifications by receiver ids and verification type +func (m ReceiverVerificationModel) GetByReceiverIdsAndVerificationField(ctx context.Context, sqlExec db.SQLExecuter, receiverIds []string, verificationField VerificationField) ([]*ReceiverVerification, error) { + receiverVerifications := []*ReceiverVerification{} + query := ` + SELECT + receiver_id, + verification_field, + hashed_value, + attempts, + created_at, + updated_at, + confirmed_at, + failed_at + FROM + receiver_verifications + WHERE + receiver_id = ANY($1) AND + verification_field = $2 + ` + err := sqlExec.SelectContext(ctx, &receiverVerifications, query, pq.Array(receiverIds), verificationField) + if err != nil { + return nil, fmt.Errorf("error querying receiver verifications: %w", err) + } + return receiverVerifications, nil +} + +// Insert inserts a new receiver verification +func (m ReceiverVerificationModel) Insert(ctx context.Context, sqlExec db.SQLExecuter, verificationInsert ReceiverVerificationInsert) (string, error) { + err := verificationInsert.Validate() + if err != nil { + return "", fmt.Errorf("error validating receiver verification insert: %w", err) + } + hashedValue, err := HashVerificationValue(verificationInsert.VerificationValue) + if err != nil { + return "", fmt.Errorf("error hashing verification value: %w", err) + } + + query := ` + INSERT INTO receiver_verifications ( + receiver_id, + verification_field, + hashed_value + ) VALUES ($1, $2, $3) + ` + + _, err = sqlExec.ExecContext(ctx, query, verificationInsert.ReceiverID, verificationInsert.VerificationField, hashedValue) + + if err != nil { + return "", fmt.Errorf("error inserting receiver verification: %w", err) + } + + return hashedValue, nil +} + +// UpdateVerificationValue updates the hashed value of a receiver verification. +func (m ReceiverVerificationModel) UpdateVerificationValue(ctx context.Context, + sqlExec db.SQLExecuter, + receiverID string, + verificationField VerificationField, + verificationValue string, +) error { + log.Ctx(ctx).Infof("Calling UpdateVerificationValue for receiver %s and verification field %s", receiverID, verificationField) + hashedValue, err := HashVerificationValue(verificationValue) + if err != nil { + return fmt.Errorf("error hashing verification value: %w", err) + } + + query := ` + UPDATE receiver_verifications + SET hashed_value = $1 + WHERE receiver_id = $2 AND verification_field = $3 + ` + + _, err = sqlExec.ExecContext(ctx, query, hashedValue, receiverID, verificationField) + + if err != nil { + return fmt.Errorf("error updating receiver verification: %w", err) + } + + return nil +} + +// UpdateVerificationValue updates the hashed value of a receiver verification. +func (m ReceiverVerificationModel) UpdateReceiverVerification(ctx context.Context, receiverVerification ReceiverVerification, sqlExec db.SQLExecuter) error { + query := ` + UPDATE + receiver_verifications + SET + attempts = $1, + confirmed_at = $2, + failed_at = $3 + WHERE + receiver_id = $4 AND verification_field = $5 + ` + + _, err := sqlExec.ExecContext(ctx, + query, + receiverVerification.Attempts, + receiverVerification.ConfirmedAt, + receiverVerification.FailedAt, + receiverVerification.ReceiverID, + receiverVerification.VerificationField, + ) + if err != nil { + return fmt.Errorf("error updating receiver verification: %w", err) + } + + return nil +} + +// ExceededAttempts check if the number of attempts exceeded the max value. +func (*ReceiverVerificationModel) ExceededAttempts(attempts int) bool { + return attempts >= MaxAttemptsAllowed +} + +func HashVerificationValue(verificationValue string) (string, error) { + hashedValue, err := bcrypt.GenerateFromPassword([]byte(verificationValue), bcrypt.MinCost) + if err != nil { + return "", fmt.Errorf("error hashing verification value: %w", err) + } + return string(hashedValue), nil +} + +func CompareVerificationValue(hashedValue, verificationValue string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hashedValue), []byte(verificationValue)) + if err != nil { + return false + } + return err == nil +} diff --git a/internal/data/receiver_verification_test.go b/internal/data/receiver_verification_test.go new file mode 100644 index 000000000..3aa078fe3 --- /dev/null +++ b/internal/data/receiver_verification_test.go @@ -0,0 +1,203 @@ +package data + +import ( + "context" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReceiverVerificationModel_GetByReceiverIdsAndVerificationField(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiver3 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + verification1 := CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, ReceiverVerificationInsert{ + ReceiverID: receiver1.ID, + VerificationField: VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + verification2 := CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, ReceiverVerificationInsert{ + ReceiverID: receiver2.ID, + VerificationField: VerificationFieldDateOfBirth, + VerificationValue: "1990-01-02", + }) + CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, ReceiverVerificationInsert{ + ReceiverID: receiver3.ID, + VerificationField: VerificationFieldPin, + VerificationValue: "1990-01-03", + }) + + verifiedReceivers := []string{receiver1.ID, receiver2.ID} + verifieldValues := []string{verification1.HashedValue, verification2.HashedValue} + + receiverVerificationModel := ReceiverVerificationModel{} + + actualVerifications, err := receiverVerificationModel.GetByReceiverIdsAndVerificationField(ctx, dbConnectionPool, []string{receiver1.ID, receiver2.ID, receiver3.ID}, VerificationFieldDateOfBirth) + require.NoError(t, err) + assert.Equal(t, 2, len(actualVerifications)) + for _, v := range actualVerifications { + assert.Equal(t, VerificationFieldDateOfBirth, v.VerificationField) + assert.Contains(t, verifiedReceivers, v.ReceiverID) + assert.Contains(t, verifieldValues, v.HashedValue) + } +} + +func Test_ReceiverVerificationModel_Insert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + receiverVerificationModel := ReceiverVerificationModel{} + + verification := ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + } + + _, err = receiverVerificationModel.Insert(ctx, dbConnectionPool, verification) + require.NoError(t, err) + + actualVerification, err := receiverVerificationModel.GetByReceiverIdsAndVerificationField(ctx, dbConnectionPool, []string{receiver.ID}, VerificationFieldDateOfBirth) + require.NoError(t, err) + verified := CompareVerificationValue(actualVerification[0].HashedValue, verification.VerificationValue) + assert.True(t, verified) + assert.Equal(t, verification.ReceiverID, actualVerification[0].ReceiverID) + assert.Equal(t, verification.VerificationField, actualVerification[0].VerificationField) +} + +func Test_ReceiverVerificationModel_UpdateVerificationValue(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + receiverVerificationModel := ReceiverVerificationModel{} + + oldExpectedValue := "1990-01-01" + actualBeforeUpdate, err := receiverVerificationModel.Insert(ctx, dbConnectionPool, ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: VerificationFieldDateOfBirth, + VerificationValue: oldExpectedValue, + }) + require.NoError(t, err) + assert.NotEmpty(t, actualBeforeUpdate) + verified := CompareVerificationValue(actualBeforeUpdate, oldExpectedValue) + assert.True(t, verified) + + newExpectedValue := "1990-01-02" + err = receiverVerificationModel.UpdateVerificationValue(ctx, dbConnectionPool, receiver.ID, VerificationFieldDateOfBirth, newExpectedValue) + require.NoError(t, err) + + actualAfterUpdate, err := receiverVerificationModel.GetByReceiverIdsAndVerificationField(ctx, dbConnectionPool, []string{receiver.ID}, VerificationFieldDateOfBirth) + require.NoError(t, err) + verified = CompareVerificationValue(actualAfterUpdate[0].HashedValue, newExpectedValue) + assert.True(t, verified) +} + +func Test_ReceiverVerificationModel_UpdateReceiverVerification(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverVerificationModel := ReceiverVerificationModel{} + + verification := CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + assert.Empty(t, verification.ConfirmedAt) + assert.Empty(t, verification.FailedAt) + assert.Equal(t, 0, verification.Attempts) + + date := time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC) + verification.Attempts = 5 + verification.ConfirmedAt = &date + verification.FailedAt = &date + + err = receiverVerificationModel.UpdateReceiverVerification(ctx, *verification, dbConnectionPool) + require.NoError(t, err) + + // validate if the receiver verification has been updated + query := ` + SELECT + rv.attempts, + rv.confirmed_at, + rv.failed_at + FROM + receiver_verifications rv + WHERE + rv.receiver_id = $1 AND rv.verification_field = $2 + ` + receiverVerificationUpdated := ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &receiverVerificationUpdated, query, verification.ReceiverID, verification.VerificationField) + require.NoError(t, err) + + assert.Equal(t, &date, receiverVerificationUpdated.ConfirmedAt) + assert.Equal(t, &date, receiverVerificationUpdated.FailedAt) + assert.Equal(t, 5, receiverVerificationUpdated.Attempts) +} + +func Test_ReceiverVerificationModel_CheckTotalAttempts(t *testing.T) { + receiverVerificationModel := &ReceiverVerificationModel{} + + t.Run("attempts exceeded the max value", func(t *testing.T) { + attempts := 6 + e := receiverVerificationModel.ExceededAttempts(attempts) + assert.True(t, e) + }) + + t.Run("attempts have not exceeded the max value", func(t *testing.T) { + attempts := 1 + e := receiverVerificationModel.ExceededAttempts(attempts) + assert.False(t, e) + }) +} + +func Test_ReceiverVerification_HashAndCompareVerificationValue(t *testing.T) { + verificationValue := "1987-01-01" + hashedVerificationInfo, err := HashVerificationValue(verificationValue) + require.NoError(t, err) + assert.NotEmpty(t, hashedVerificationInfo) + + assert.NotEqual(t, verificationValue, hashedVerificationInfo) + assert.Len(t, hashedVerificationInfo, 60) + + compare := CompareVerificationValue(hashedVerificationInfo, verificationValue) + assert.True(t, compare) +} diff --git a/internal/data/receiver_wallets_state_machine.go b/internal/data/receiver_wallets_state_machine.go new file mode 100644 index 000000000..b60696fd2 --- /dev/null +++ b/internal/data/receiver_wallets_state_machine.go @@ -0,0 +1,33 @@ +package data + +type ReceiversWalletStatus string + +const ( + DraftReceiversWalletStatus ReceiversWalletStatus = "DRAFT" + ReadyReceiversWalletStatus ReceiversWalletStatus = "READY" + RegisteredReceiversWalletStatus ReceiversWalletStatus = "REGISTERED" + FlaggedReceiversWalletStatus ReceiversWalletStatus = "FLAGGED" +) + +// TransitionTo transitions the receiver wallet status to the target state +func (status ReceiversWalletStatus) TransitionTo(targetState ReceiversWalletStatus) error { + return ReceiversWalletStateMachineWithInitialState(status).TransitionTo(targetState.State()) +} + +// ReceiversWalletStateMachineWithInitialState returns a state machine for ReceiversWallets initialized with the given state +func ReceiversWalletStateMachineWithInitialState(initialState ReceiversWalletStatus) *StateMachine { + transitions := []StateTransition{ + {From: DraftReceiversWalletStatus.State(), To: ReadyReceiversWalletStatus.State()}, // disbursement started + {From: ReadyReceiversWalletStatus.State(), To: RegisteredReceiversWalletStatus.State()}, // receiver signed up + {From: ReadyReceiversWalletStatus.State(), To: FlaggedReceiversWalletStatus.State()}, // flagged + {From: FlaggedReceiversWalletStatus.State(), To: ReadyReceiversWalletStatus.State()}, // unflagged + {From: RegisteredReceiversWalletStatus.State(), To: FlaggedReceiversWalletStatus.State()}, // flagged + {From: FlaggedReceiversWalletStatus.State(), To: RegisteredReceiversWalletStatus.State()}, // unflagged + } + + return NewStateMachine(initialState.State(), transitions) +} + +func (status ReceiversWalletStatus) State() State { + return State(status) +} diff --git a/internal/data/receiver_wallets_state_machine_test.go b/internal/data/receiver_wallets_state_machine_test.go new file mode 100644 index 000000000..a35684aac --- /dev/null +++ b/internal/data/receiver_wallets_state_machine_test.go @@ -0,0 +1,63 @@ +package data + +import "testing" + +func Test_ReceiversWalletStatus_TransitionTo(t *testing.T) { + tests := []struct { + name string + initial ReceiversWalletStatus + target ReceiversWalletStatus + wantErr bool + }{ + { + "DRAFT to READY", + DraftReceiversWalletStatus, + ReadyReceiversWalletStatus, + false, + }, + { + "READY to REGISTERED", + ReadyReceiversWalletStatus, + RegisteredReceiversWalletStatus, + false, + }, + { + "READY to FLAGGED", + ReadyReceiversWalletStatus, + FlaggedReceiversWalletStatus, + false, + }, + { + "FLAGGED to READY", + FlaggedReceiversWalletStatus, + ReadyReceiversWalletStatus, + false, + }, + { + "REGISTERED to FLAGGED", + RegisteredReceiversWalletStatus, + FlaggedReceiversWalletStatus, + false, + }, + { + "FLAGGED to REGISTERED", + FlaggedReceiversWalletStatus, + RegisteredReceiversWalletStatus, + false, + }, + { + "DRAFT to REGISTERED", + DraftReceiversWalletStatus, + RegisteredReceiversWalletStatus, + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.initial.TransitionTo(tt.target); (err != nil) != tt.wantErr { + t.Errorf("ReceiversWalletStatus.TransitionTo() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/data/receivers.go b/internal/data/receivers.go new file mode 100644 index 000000000..3e77b5418 --- /dev/null +++ b/internal/data/receivers.go @@ -0,0 +1,437 @@ +package data + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/lib/pq" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type Receiver struct { + ID string `json:"id" db:"id"` + Email *string `json:"email,omitempty" db:"email"` + PhoneNumber string `json:"phone_number,omitempty" db:"phone_number"` + ExternalID string `json:"external_id,omitempty" db:"external_id"` + CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"` + UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"` + ReceiverStats +} + +type ReceiverRegistrationRequest struct { + PhoneNumber string `json:"phone_number"` + OTP string `json:"otp"` + VerificationValue string `json:"verification"` + VerificationType VerificationField `json:"verification_type"` + ReCAPTCHAToken string `json:"recaptcha_token"` +} + +type ReceiverStats struct { + TotalPayments string `json:"total_payments,omitempty" db:"total_payments"` + SuccessfulPayments string `json:"successful_payments,omitempty" db:"successful_payments"` + FailedPayments string `json:"failed_payments,omitempty" db:"failed_payments"` + RemainingPayments string `json:"remaining_payments,omitempty" db:"remaining_payments"` + RegisteredWallets string `json:"registered_wallets,omitempty" db:"registered_wallets"` + ReceivedAmounts ReceivedAmounts `json:"received_amounts,omitempty" db:"received_amounts"` +} + +type Amount struct { + AssetCode string `json:"asset_code" db:"asset_code"` + AssetIssuer string `json:"asset_issuer" db:"asset_issuer"` + ReceivedAmount string `json:"received_amount" db:"received_amount"` +} + +var ( + DefaultReceiverSortField = SortFieldUpdatedAt + DefaultReceiverSortOrder = SortOrderDESC + AllowedReceiverFilters = []FilterKey{FilterKeyStatus, FilterKeyCreatedAtAfter, FilterKeyCreatedAtBefore} + AllowedReceiverSorts = []SortField{SortFieldCreatedAt, SortFieldUpdatedAt} +) + +type ReceiverModel struct{} + +type ReceiverInsert struct { + PhoneNumber string `db:"phone_number"` + ExternalId *string `db:"external_id"` +} + +type ReceiverUpdate struct { + Email string `db:"email"` + ExternalId string `db:"external_id"` +} + +type ReceivedAmounts []Amount + +// Scan implements the sql.Scanner interface. +func (ra *ReceivedAmounts) Scan(src interface{}) error { + var receivedAmounts sql.NullString + if err := (&receivedAmounts).Scan(src); err != nil { + return fmt.Errorf("error scanning status history value: %w", err) + } + + if receivedAmounts.Valid { + var shEntry []Amount + err := json.Unmarshal([]byte(receivedAmounts.String), &shEntry) + if err != nil { + return fmt.Errorf("error unmarshaling status_history column: %w", err) + } + + *ra = shEntry + } + + return nil +} + +// Get returns a RECEIVER matching the given ID. +func (r *ReceiverModel) Get(ctx context.Context, sqlExec db.SQLExecuter, id string) (*Receiver, error) { + receiver := Receiver{} + + query := ` + WITH receivers_cte AS ( + SELECT + r.id, + r.external_id, + r.phone_number, + r.email, + r.created_at, + r.updated_at + FROM receivers r + WHERE r.id = $1 + ), receiver_wallets_cte AS ( + SELECT + rc.id as receiver_id, + COUNT(rw) FILTER(WHERE rw.status = 'REGISTERED') as registered_wallets + FROM receivers_cte rc + JOIN receiver_wallets rw ON rc.id = rw.receiver_id + GROUP BY rc.id + ), receiver_stats AS ( + SELECT + rc.id as receiver_id, + COUNT(p) as total_payments, + COUNT(p) FILTER(WHERE p.status = 'SUCCESS') as successful_payments, + COUNT(p) FILTER(WHERE p.status = 'FAILED') as failed_payments, + COUNT(p) FILTER(WHERE p.status IN ('DRAFT', 'READY', 'PENDING', 'PAUSED')) as remaining_payments, + a.code as asset_code, + a.issuer as asset_issuer, + COALESCE(SUM(p.amount) FILTER(WHERE p.asset_id = a.id AND p.status = 'SUCCESS'), '0') as received_amount + FROM receivers_cte rc + JOIN payments p ON rc.id = p.receiver_id + JOIN disbursements d ON p.disbursement_id = d.id + JOIN assets a ON a.id = p.asset_id + GROUP BY (rc.id, a.code, a.issuer) + ), receiver_stats_aggregate AS ( + SELECT + rs.receiver_id, + SUM(rs.total_payments) as total_payments, + SUM(rs.successful_payments) as successful_payments, + SUM(rs.failed_payments) as failed_payments, + SUM(rs.remaining_payments) as remaining_payments, + jsonb_agg(jsonb_build_object('asset_code', rs.asset_code, 'asset_issuer', rs.asset_issuer, 'received_amount', rs.received_amount::text)) as received_amounts + FROM receiver_stats rs + GROUP BY (rs.receiver_id) + ) + SELECT + rc.id, + rc.external_id, + COALESCE(rc.email, '') as email, + rc.phone_number, + rc.created_at, + rc.updated_at, + COALESCE(total_payments, 0) as total_payments, + COALESCE(successful_payments, 0) as successful_payments, + COALESCE(rs.failed_payments, '0') as failed_payments, + COALESCE(rs.remaining_payments, '0') as remaining_payments, + rs.received_amounts, + COALESCE(rw.registered_wallets, 0) as registered_wallets + FROM receivers_cte rc + LEFT JOIN receiver_stats_aggregate rs ON rs.receiver_id = rc.id + LEFT JOIN receiver_wallets_cte rw ON rw.receiver_id = rc.id + ` + + err := sqlExec.GetContext(ctx, &receiver, query, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } else { + return nil, fmt.Errorf("error querying receiver ID: %w", err) + } + } + + return &receiver, nil +} + +// Count returns the number of receivers matching the given query parameters. +func (r *ReceiverModel) Count(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams) (int, error) { + var count int + baseQuery := ` + SELECT + COUNT(DISTINCT r.id) + FROM receivers r + LEFT JOIN receiver_wallets rw ON rw.receiver_id = r.id + ` + query, params := newReceiverQuery(baseQuery, queryParams, false, sqlExec) + + err := sqlExec.GetContext(ctx, &count, query, params...) + if err != nil { + return 0, fmt.Errorf("error counting payments: %w", err) + } + + return count, nil +} + +// GetAll returns all RECEIVERS matching the given query parameters. +func (r *ReceiverModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, queryParams *QueryParams) ([]Receiver, error) { + receivers := []Receiver{} + + baseQuery := ` + WITH receivers_cte AS ( + %s + ), registered_receiver_wallets_count_cte AS ( + SELECT + rc.id as receiver_id, + COUNT(rw) FILTER(WHERE rw.status = 'REGISTERED') as registered_wallets + FROM receivers_cte rc + JOIN receiver_wallets rw ON rc.id = rw.receiver_id + GROUP BY rc.id + ), receiver_stats AS ( + SELECT + rc.id as receiver_id, + COUNT(p) as total_payments, + COUNT(p) FILTER(WHERE p.status = 'SUCCESS') as successful_payments, + COUNT(p) FILTER(WHERE p.status = 'FAILED') as failed_payments, + COUNT(p) FILTER(WHERE p.status IN ('DRAFT', 'READY', 'PENDING', 'PAUSED')) as remaining_payments, + a.code as asset_code, + a.issuer as asset_issuer, + COALESCE(SUM(p.amount) FILTER(WHERE p.asset_id = a.id AND p.status = 'SUCCESS'), '0') as received_amount + FROM receivers_cte rc + JOIN payments p ON rc.id = p.receiver_id + JOIN disbursements d ON p.disbursement_id = d.id + JOIN assets a ON a.id = p.asset_id + GROUP BY (rc.id, a.code, a.issuer) + ), receiver_stats_aggregate AS ( + SELECT + rs.receiver_id, + SUM(rs.total_payments) as total_payments, + SUM(rs.successful_payments) as successful_payments, + SUM(rs.failed_payments) as failed_payments, + SUM(rs.remaining_payments) as remaining_payments, + jsonb_agg(jsonb_build_object('asset_code', rs.asset_code, 'asset_issuer', rs.asset_issuer, 'received_amount', rs.received_amount::text)) as received_amounts + FROM receiver_stats rs + GROUP BY (rs.receiver_id) + ) + SELECT + distinct(r.id), + r.external_id, + COALESCE(r.email, '') as email, + r.phone_number, + r.created_at, + r.updated_at, + COALESCE(total_payments, 0) as total_payments, + COALESCE(successful_payments, 0) as successful_payments, + COALESCE(rs.failed_payments, '0') as failed_payments, + COALESCE(rs.remaining_payments, '0') as remaining_payments, + rs.received_amounts, + COALESCE(rrwc.registered_wallets, 0) as registered_wallets + FROM receivers_cte r + LEFT JOIN receiver_stats_aggregate rs ON rs.receiver_id = r.id + LEFT JOIN receiver_wallets rw ON rw.receiver_id = r.id + LEFT JOIN registered_receiver_wallets_count_cte rrwc ON rrwc.receiver_id = r.id + ` + + receiverQuery := ` + SELECT + r.id, + r.email, + r.external_id, + r.phone_number, + r.created_at, + r.updated_at + FROM + receivers r + ` + + query := fmt.Sprintf(baseQuery, receiverQuery) + query, params := newReceiverQuery(query, queryParams, true, sqlExec) + + err := sqlExec.SelectContext(ctx, &receivers, query, params...) + if err != nil { + return nil, fmt.Errorf("error querying receivers: %w", err) + } + + return receivers, nil +} + +// newReceiverQuery generates the full query and parameters for a receiver search query +func newReceiverQuery(baseQuery string, queryParams *QueryParams, paginated bool, sqlExec db.SQLExecuter) (string, []interface{}) { + qb := NewQueryBuilder(baseQuery) + if queryParams.Query != "" { + q := "%" + queryParams.Query + "%" + qb.AddCondition("(r.id ILIKE ? OR r.phone_number ILIKE ? OR r.email ILIKE ?)", q, q, q) + } + if queryParams.Filters[FilterKeyStatus] != nil { + status := queryParams.Filters[FilterKeyStatus].(ReceiversWalletStatus) + qb.AddCondition("rw.status = ?", status) + } + if queryParams.Filters[FilterKeyCreatedAtAfter] != nil { + qb.AddCondition("r.created_at >= ?", queryParams.Filters[FilterKeyCreatedAtAfter]) + } + if queryParams.Filters[FilterKeyCreatedAtBefore] != nil { + qb.AddCondition("r.created_at <= ?", queryParams.Filters[FilterKeyCreatedAtBefore]) + } + if paginated { + qb.AddSorting(queryParams.SortBy, queryParams.SortOrder, "r") + qb.AddPagination(queryParams.Page, queryParams.PageLimit) + } + query, params := qb.Build() + return sqlExec.Rebind(query), params +} + +type ReceiverIDs []string + +// ParseReceiverIDs return the array of receivers IDs. +func (r *ReceiverModel) ParseReceiverIDs(receivers []Receiver) ReceiverIDs { + receiverIds := make(ReceiverIDs, 0) + + for _, receiver := range receivers { + receiverIds = append(receiverIds, receiver.ID) + } + + return receiverIds +} + +// Insert inserts a new receiver into the database. +func (r *ReceiverModel) Insert(ctx context.Context, sqlExec db.SQLExecuter, insert ReceiverInsert) (*Receiver, error) { + query := ` + INSERT INTO receivers ( + phone_number, + external_id + ) VALUES ( + $1, + $2 + ) RETURNING + id, + phone_number, + external_id, + created_at, + updated_at + ` + + var receiver Receiver + err := sqlExec.GetContext(ctx, &receiver, query, insert.PhoneNumber, insert.ExternalId) + if err != nil { + return nil, fmt.Errorf("error inserting receiver: %w", err) + } + + return &receiver, nil +} + +// Update updates the receiver Email and/or External ID. +func (r *ReceiverModel) Update(ctx context.Context, sqlExec db.SQLExecuter, ID string, receiverUpdate ReceiverUpdate) error { + if receiverUpdate.Email == "" && receiverUpdate.ExternalId == "" { + return fmt.Errorf("provide at least one of these values: Email or ExternalID") + } + + args := []interface{}{} + fields := []string{} + if receiverUpdate.Email != "" { + if err := utils.ValidateEmail(receiverUpdate.Email); err != nil { + return fmt.Errorf("error validating email: %w", err) + } + + fields = append(fields, "email = ?") + args = append(args, receiverUpdate.Email) + } + + if receiverUpdate.ExternalId != "" { + fields = append(fields, "external_id = ?") + args = append(args, receiverUpdate.ExternalId) + } + + args = append(args, ID) + + query := ` + UPDATE + receivers + SET + %s + WHERE + id = ? + ` + + query = sqlExec.Rebind(fmt.Sprintf(query, strings.Join(fields, ", "))) + + _, err := sqlExec.ExecContext(ctx, query, args...) + if err != nil { + return fmt.Errorf("error updating receiver: %w", err) + } + + return nil +} + +// GetByPhoneNumbers search for receivers by phone numbers +func (r *ReceiverModel) GetByPhoneNumbers(ctx context.Context, sqlExec db.SQLExecuter, ids []string) ([]*Receiver, error) { + receivers := []*Receiver{} + + query := ` + SELECT + r.id, + r.phone_number, + r.external_id, + r.created_at, + r.updated_at + FROM receivers r + WHERE r.phone_number = ANY($1) + ` + err := sqlExec.SelectContext(ctx, &receivers, query, pq.Array(ids)) + if err != nil { + return nil, fmt.Errorf("error fetching receiver ids by phone numbers: %w", err) + } + return receivers, nil +} + +// DeleteByPhoneNumber deletes a receiver by phone number. It also deletes the associated entries in other tables: +// messages, payments, receiver_verifications, receiver_wallets, receivers, disbursements, submitter_transactions +func (r *ReceiverModel) DeleteByPhoneNumber(ctx context.Context, dbConnectionPool db.DBConnectionPool, phoneNumber string) error { + return db.RunInTransaction(ctx, dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + query := "SELECT id FROM receivers WHERE phone_number = $1" + var receiverID string + + err := dbTx.GetContext(ctx, &receiverID, query, phoneNumber) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrRecordNotFound + } + return fmt.Errorf("error fetching receiver by phone number %s: %w", phoneNumber, err) + } + + type QueryWithParams struct { + Query string + Params []interface{} + } + + queries := []QueryWithParams{ + {"DELETE FROM messages WHERE receiver_id = $1", []interface{}{receiverID}}, + {"DELETE FROM receiver_verifications WHERE receiver_id = $1", []interface{}{receiverID}}, + {"DELETE FROM payments WHERE receiver_id = $1", []interface{}{receiverID}}, + {"DELETE FROM receiver_wallets WHERE receiver_id = $1", []interface{}{receiverID}}, + {"DELETE FROM receivers WHERE id = $1", []interface{}{receiverID}}, + {"DELETE FROM disbursements WHERE id NOT IN (SELECT DISTINCT disbursement_id FROM payments)", nil}, + } + + for _, qwp := range queries { + _, err = dbTx.ExecContext(ctx, qwp.Query, qwp.Params...) + if err != nil { + return fmt.Errorf("error executing query %q: %w", qwp.Query, err) + } + } + + return nil + }) +} diff --git a/internal/data/receivers_test.go b/internal/data/receivers_test.go new file mode 100644 index 000000000..a43162c0c --- /dev/null +++ b/internal/data/receivers_test.go @@ -0,0 +1,1169 @@ +package data + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReceiversModelGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + disbursement := Disbursement{ + Status: DraftDisbursementStatus, + Asset: asset, + Country: country, + Wallet: wallet, + } + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + paymentModel := PaymentModel{dbConnectionPool: dbConnectionPool} + payment := Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Asset: *asset, + ReceiverWallet: receiverWallet, + } + + receiverModel := ReceiverModel{} + + t.Run("returns error when receiver does not exist", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + gotReceiver, err := receiverModel.Get(ctx, dbTx, "invalid_id") + require.Error(t, err) + require.ErrorIs(t, ErrRecordNotFound, err) + require.Nil(t, gotReceiver) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receiver without payments", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverModel.Get(ctx, dbTx, receiver.ID) + require.NoError(t, err) + + expected := Receiver{ + ID: receiver.ID, + ExternalID: receiver.ExternalID, + Email: receiver.Email, + PhoneNumber: receiver.PhoneNumber, + CreatedAt: receiver.CreatedAt, + UpdatedAt: receiver.UpdatedAt, + ReceiverStats: ReceiverStats{ + TotalPayments: "0", + SuccessfulPayments: "0", + FailedPayments: "0", + RemainingPayments: "0", + RegisteredWallets: "0", + ReceivedAmounts: nil, + }, + } + assert.Equal(t, expected, *actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receiver with payment", func(t *testing.T) { + disbursement.Name = "disbursement 1" + d := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = DraftPaymentStatus + payment.Disbursement = d + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverModel.Get(ctx, dbTx, receiver.ID) + require.NoError(t, err) + expected := Receiver{ + ID: receiver.ID, + ExternalID: receiver.ExternalID, + Email: receiver.Email, + PhoneNumber: receiver.PhoneNumber, + CreatedAt: receiver.CreatedAt, + UpdatedAt: receiver.UpdatedAt, + ReceiverStats: ReceiverStats{ + TotalPayments: "1", + SuccessfulPayments: "0", + FailedPayments: "0", + RemainingPayments: "1", + RegisteredWallets: "0", + ReceivedAmounts: []Amount{ + { + AssetCode: "USDC", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "0", + }, + }, + }, + } + assert.Equal(t, expected, *actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receiver with successful payment", func(t *testing.T) { + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement 1" + d := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = DraftPaymentStatus + payment.Disbursement = d + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + disbursement.Name = "disbursement 2" + d = CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = SuccessPaymentStatus + payment.Disbursement = d + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverModel.Get(ctx, dbTx, receiver.ID) + require.NoError(t, err) + expected := Receiver{ + ID: receiver.ID, + ExternalID: receiver.ExternalID, + Email: receiver.Email, + PhoneNumber: receiver.PhoneNumber, + CreatedAt: receiver.CreatedAt, + UpdatedAt: receiver.UpdatedAt, + ReceiverStats: ReceiverStats{ + TotalPayments: "2", + SuccessfulPayments: "1", + FailedPayments: "0", + RemainingPayments: "1", + RegisteredWallets: "0", + ReceivedAmounts: []Amount{ + { + AssetCode: "USDC", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "50.0000000", + }, + }, + }, + } + assert.Equal(t, expected, *actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receiver with multiple assets", func(t *testing.T) { + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement 1" + d := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = DraftPaymentStatus + payment.Disbursement = d + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + asset2 := CreateAssetFixture(t, ctx, dbConnectionPool, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + disbursement.Name = "disbursement 2" + disbursement.Asset = asset2 + d = CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = SuccessPaymentStatus + payment.Disbursement = d + payment.Asset = *asset2 + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverModel.Get(ctx, dbTx, receiver.ID) + require.NoError(t, err) + expected := Receiver{ + ID: receiver.ID, + ExternalID: receiver.ExternalID, + Email: receiver.Email, + PhoneNumber: receiver.PhoneNumber, + CreatedAt: receiver.CreatedAt, + UpdatedAt: receiver.UpdatedAt, + ReceiverStats: ReceiverStats{ + TotalPayments: "2", + SuccessfulPayments: "1", + FailedPayments: "0", + RemainingPayments: "1", + RegisteredWallets: "0", + ReceivedAmounts: []Amount{ + { + AssetCode: "EURT", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "50.0000000", + }, + { + AssetCode: "USDC", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "0", + }, + }, + }, + } + assert.Equal(t, expected, *actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receiver using db transaction", func(t *testing.T) { + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement 1" + disbursement.Asset = asset + d := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = DraftPaymentStatus + payment.Disbursement = d + payment.Asset = *asset + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + // Initializing a new Tx. + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverModel.Get(ctx, dbTx, receiver.ID) + require.NoError(t, err) + expected := Receiver{ + ID: receiver.ID, + ExternalID: receiver.ExternalID, + Email: receiver.Email, + PhoneNumber: receiver.PhoneNumber, + CreatedAt: receiver.CreatedAt, + UpdatedAt: receiver.UpdatedAt, + ReceiverStats: ReceiverStats{ + TotalPayments: "1", + SuccessfulPayments: "0", + FailedPayments: "0", + RemainingPayments: "1", + RegisteredWallets: "0", + ReceivedAmounts: []Amount{ + { + AssetCode: "USDC", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "0", + }, + }, + }, + } + + assert.Equal(t, expected, *actual) + + // Commit the transaction. + commitErr := dbTx.Commit() + require.NoError(t, commitErr) + }) +} + +func Test_ReceiversModelCount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverModel := ReceiverModel{} + + t.Run("returns 0 when no receivers exist", func(t *testing.T) { + dbTx, innerErr := dbConnectionPool.BeginTxx(ctx, nil) + // Defer a rollback in case anything fails. + defer func() { + innerErr = dbTx.Rollback() + require.Error(t, innerErr, "not in transaction") + }() + + count, innerErr := receiverModel.Count(ctx, dbTx, &QueryParams{}) + require.NoError(t, innerErr) + assert.Equal(t, 0, count) + + innerErr = dbTx.Commit() + require.NoError(t, innerErr) + }) + + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, DraftReceiversWalletStatus) + CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, RegisteredReceiversWalletStatus) + + t.Run("returns count of receivers", func(t *testing.T) { + dbTx, innerErr := dbConnectionPool.BeginTxx(ctx, nil) + // Defer a rollback in case anything fails. + defer func() { + innerErr = dbTx.Rollback() + require.Error(t, innerErr, "not in transaction") + }() + + count, innerErr := receiverModel.Count(ctx, dbTx, &QueryParams{}) + require.NoError(t, innerErr) + assert.Equal(t, 2, count) + + innerErr = dbTx.Commit() + require.NoError(t, innerErr) + }) + + t.Run("returns count of receivers with filter", func(t *testing.T) { + dbTx, innerErr := dbConnectionPool.BeginTxx(ctx, nil) + // Defer a rollback in case anything fails. + defer func() { + innerErr = dbTx.Rollback() + require.Error(t, innerErr, "not in transaction") + }() + + filters := map[FilterKey]interface{}{ + FilterKeyStatus: DraftReceiversWalletStatus, + } + count, innerErr := receiverModel.Count(ctx, dbTx, &QueryParams{Filters: filters}) + require.NoError(t, innerErr) + assert.Equal(t, 1, count) + + innerErr = dbTx.Commit() + require.NoError(t, innerErr) + }) + + t.Run("returns count of receivers with session", func(t *testing.T) { + // Initializing a new Tx. + dbTx, innerErr := dbConnectionPool.BeginTxx(ctx, nil) + // Defer a rollback in case anything fails. + defer func() { + innerErr = dbTx.Rollback() + require.Error(t, innerErr, "not in transaction") + }() + + count, innerErr := receiverModel.Count(ctx, dbTx, &QueryParams{}) + require.NoError(t, innerErr) + assert.Equal(t, 2, count) + + // Commit the transaction. + innerErr = dbTx.Commit() + require.NoError(t, innerErr) + }) +} + +func Test_ReceiversModelGetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverModel := ReceiverModel{} + + t.Run("returns empty list when no receiver exist", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + receivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{}) + require.NoError(t, err) + assert.Equal(t, 0, len(receivers)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + date := time.Date(2023, 1, 10, 23, 40, 20, 1431, time.UTC) + receiver1Email := "receiver1@mock.com" + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{ + Email: &receiver1Email, + PhoneNumber: "+99991111", + ExternalID: "external-id-1", + CreatedAt: &date, + UpdatedAt: &date, + }) + + date = time.Date(2023, 3, 10, 23, 40, 20, 1431, time.UTC) + receiver2Email := "receiver2@mock.com" + receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{ + Email: &receiver2Email, + PhoneNumber: "+99992222", + ExternalID: "external-id-2", + CreatedAt: &date, + UpdatedAt: &date, + }) + + CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, DraftReceiversWalletStatus) + CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, RegisteredReceiversWalletStatus) + + t.Run("returns receiver successfully", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}) + require.NoError(t, err) + assert.Equal(t, 2, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external-id-1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0" + }, + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external-id-2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"1" + } + ]`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano), + receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano)) + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receivers successfully with limit", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{ + SortBy: SortFieldCreatedAt, + SortOrder: SortOrderASC, + Page: 1, + PageLimit: 1, + }) + require.NoError(t, err) + assert.Equal(t, 1, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external-id-1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0" + } + ]`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)) + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receivers successfully with offset", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{ + SortBy: SortFieldCreatedAt, + SortOrder: SortOrderASC, + Page: 2, + PageLimit: 1, + }) + require.NoError(t, err) + assert.Equal(t, 1, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external-id-2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"1" + } + ]`, receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano)) + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receivers successfully with status filter", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + if err != nil { + err = dbTx.Rollback() + require.NoError(t, err, "not in transaction") + } + }() + + filters := map[FilterKey]interface{}{ + FilterKeyStatus: DraftReceiversWalletStatus, + } + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Filters: filters}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external-id-1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0" + } + ]`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)) + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receivers successfully with query filter email", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Query: receiver1Email}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external-id-1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0" + } + ]`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)) + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receivers successfully with query filter phone number", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Query: "+99992222"}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external-id-2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"1" + } + ]`, receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano)) + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receivers successfully with date filter", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + filters := map[FilterKey]interface{}{ + FilterKeyCreatedAtAfter: "2023-01-01", + FilterKeyCreatedAtBefore: "2023-03-01", + } + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{Filters: filters}) + require.NoError(t, err) + assert.Equal(t, 1, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external-id-1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0" + } + ]`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)) + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receiver successfully with session", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actualReceivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}) + require.NoError(t, err) + assert.Equal(t, 2, len(actualReceivers)) + + ar, err := json.Marshal(actualReceivers) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external-id-1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0" + }, + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external-id-2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"1" + } + ]`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano), + receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano)) + + assert.JSONEq(t, wantJson, string(ar)) + + // Commit the transaction. + commitErr := dbTx.Commit() + require.NoError(t, commitErr) + }) +} + +func Test_ReceiversModel_GetAll_makeSureReceiversWithMultipleWalletsWillReturnASingleResult(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverModel := ReceiverModel{} + + wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + wallet2 := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet2", "https://www.wallet2.com", "www.wallet2.com", "wallet2://") + + receiver1Email := "receiver1@mock.com" + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{ + Email: &receiver1Email, + PhoneNumber: "+99991111", + ExternalID: "external-id-1", + }) + + CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, ReadyReceiversWalletStatus) + CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, RegisteredReceiversWalletStatus) + + receivers, err := receiverModel.GetAll(ctx, dbConnectionPool, &QueryParams{}) + require.NoError(t, err) + + assert.Len(t, receivers, 1) +} + +func Test_ReceiversModel_ParseReceiverIDs(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverModel := ReceiverModel{} + + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + receivers, err := receiverModel.GetAll(ctx, dbTx, &QueryParams{SortBy: SortFieldCreatedAt, SortOrder: SortOrderASC}) + require.NoError(t, err) + + receiverIds := receiverModel.ParseReceiverIDs(receivers) + expectedIds := ReceiverIDs{receiver1.ID, receiver2.ID} + assert.Equal(t, expectedIds, receiverIds) + + err = dbTx.Commit() + require.NoError(t, err) +} + +func Test_DeleteByPhoneNumber(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + models, err := NewModels(dbConnectionPool) + require.NoError(t, err) + + // 0. returns ErrNotFound for users that don't exist: + t.Run("User does not exist", func(t *testing.T) { + err = models.Receiver.DeleteByPhoneNumber(ctx, dbConnectionPool, "+14152222222") + require.ErrorIs(t, err, ErrRecordNotFound) + }) + + // 1. Create country, asset, and wallet (won't be deleted) + country := CreateCountryFixture(t, ctx, dbConnectionPool, "ATL", "Atlantis") + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "FOO1", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "walletA", "https://www.a.com", "www.a.com", "a://") + + // 2. Create receiverX (that will be deleted) and all receiverX dependent resources that will also be deleted: + receiverX := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWalletX := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverX.ID, wallet.ID, DraftReceiversWalletStatus) + _ = CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, ReceiverVerificationInsert{ + ReceiverID: receiverX.ID, + VerificationField: VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + messageX := CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiverX.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWalletX.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + disbursement1 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Country: country, + Wallet: wallet, + Status: ReadyDisbursementStatus, + Asset: asset, + }) + paymentX1 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletX, + Disbursement: disbursement1, + Asset: *asset, + Status: ReadyPaymentStatus, + Amount: "1", + }) + + // 3. Create receiverY (that will not be deleted) and all receiverY dependent resources that will not be deleted: + receiverY := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiverWalletY := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiverY.ID, wallet.ID, DraftReceiversWalletStatus) + _ = CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, ReceiverVerificationInsert{ + ReceiverID: receiverY.ID, + VerificationField: VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + messageY := CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiverY.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWalletY.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + disbursement2 := CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &Disbursement{ + Country: country, + Wallet: wallet, + Status: ReadyDisbursementStatus, + Asset: asset, + }) + paymentY2 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletY, + Disbursement: disbursement2, + Asset: *asset, + Status: ReadyPaymentStatus, + Amount: "1", + }) + + paymentX2 := CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &Payment{ + ReceiverWallet: receiverWalletX, + Disbursement: disbursement2, + Asset: *asset, + Status: ReadyPaymentStatus, + Amount: "1", + }) // This payment will be deleted along with the remaining receiverX-related data + + // 4. Delete receiverX + err = models.Receiver.DeleteByPhoneNumber(ctx, dbConnectionPool, receiverX.PhoneNumber) + require.NoError(t, err) + + type testCase struct { + name string + query string + args []interface{} + wantExists bool + } + + // 5. Prepare assertions to make sure `DeleteByPhoneNumber` DID DELETE receiverX-related data: + didDeleteTestCases := []testCase{ + { + name: "DID DELETE: receiverX", + query: "SELECT EXISTS(SELECT 1 FROM receivers WHERE id = $1)", + args: []interface{}{receiverX.ID}, + wantExists: false, + }, + { + name: "DID DELETE: receiverWalletX", + query: "SELECT EXISTS(SELECT 1 FROM receiver_wallets WHERE id = $1)", + args: []interface{}{receiverWalletX.ID}, + wantExists: false, + }, + { + name: "DID DELETE: receiverVerificationX", + query: "SELECT EXISTS(SELECT 1 FROM receiver_verifications WHERE receiver_id = $1)", + args: []interface{}{receiverX.ID}, + wantExists: false, + }, + { + name: "DID DELETE: messageX", + query: "SELECT EXISTS(SELECT 1 FROM messages WHERE id = $1)", + args: []interface{}{messageX.ID}, + wantExists: false, + }, + { + name: "DID DELETE: paymentX", + query: "SELECT EXISTS(SELECT 1 FROM payments WHERE id = ANY($1))", + args: []interface{}{pq.Array([]string{paymentX1.ID, paymentX2.ID})}, + wantExists: false, + }, + { + name: "DID DELETE: disbursement1", + query: "SELECT EXISTS(SELECT 1 FROM disbursements WHERE id = $1)", + args: []interface{}{disbursement1.ID}, + wantExists: false, + }, + } + + // 6. Prepare assertions to make sure `DeleteByPhoneNumber` DID NOT DELETE receiverY-related data: + didNotDeleteTestCases := []testCase{ + { + name: "DID NOT DELETE: receiverY", + query: "SELECT EXISTS(SELECT 1 FROM receivers WHERE id = $1)", + args: []interface{}{receiverY.ID}, + wantExists: true, + }, + { + name: "DID NOT DELETE: receiverWalletY", + query: "SELECT EXISTS(SELECT 1 FROM receiver_wallets WHERE id = $1)", + args: []interface{}{receiverWalletY.ID}, + wantExists: true, + }, + { + name: "DID NOT DELETE: receiverVerificationY", + query: "SELECT EXISTS(SELECT 1 FROM receiver_verifications WHERE receiver_id = $1)", + args: []interface{}{receiverY.ID}, + wantExists: true, + }, + { + name: "DID NOT DELETE: messageY", + query: "SELECT EXISTS(SELECT 1 FROM messages WHERE id = $1)", + args: []interface{}{messageY.ID}, + wantExists: true, + }, + { + name: "DID NOT DELETE: paymentY2", + query: "SELECT EXISTS(SELECT 1 FROM payments WHERE id = $1)", + args: []interface{}{paymentY2.ID}, + wantExists: true, + }, + { + name: "DID NOT DELETE: paymentX2", + query: "SELECT EXISTS(SELECT 1 FROM disbursements WHERE id = $1)", + args: []interface{}{disbursement2.ID}, + wantExists: true, + }, + } + + // 7. Run assertions + testCases := append(didDeleteTestCases, didNotDeleteTestCases...) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var exists bool + err = dbConnectionPool.QueryRowxContext(ctx, tc.query, tc.args...).Scan(&exists) + require.NoError(t, err) + require.Equal(t, tc.wantExists, exists) + }) + } +} + +func Test_ReceiversModel_Update(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverModel := ReceiverModel{} + + email, externalID := "receiver@email.com", "externalID" + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{Email: &email, ExternalID: externalID}) + + resetReceiver := func(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, receiverID string) { + q := ` + UPDATE receivers SET email = $1, external_id = $2 WHERE id = $3 + ` + _, err = sqlExec.ExecContext(ctx, q, email, externalID, receiverID) + require.NoError(t, err) + } + + t.Run("returns error when no value is provided", func(t *testing.T) { + resetReceiver(t, ctx, dbConnectionPool, receiver.ID) + + err = receiverModel.Update(ctx, dbConnectionPool, receiver.ID, ReceiverUpdate{ + Email: "", + ExternalId: "", + }) + assert.EqualError(t, err, "provide at least one of these values: Email or ExternalID") + }) + + t.Run("returns error when email is invalid", func(t *testing.T) { + resetReceiver(t, ctx, dbConnectionPool, receiver.ID) + + err = receiverModel.Update(ctx, dbConnectionPool, receiver.ID, ReceiverUpdate{ + Email: "invalid", + ExternalId: "", + }) + assert.EqualError(t, err, `error validating email: the provided email is not valid`) + }) + + t.Run("updates email name successfully", func(t *testing.T) { + resetReceiver(t, ctx, dbConnectionPool, receiver.ID) + + receiver, err = receiverModel.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.Equal(t, email, *receiver.Email) + assert.Equal(t, externalID, receiver.ExternalID) + + err = receiverModel.Update(ctx, dbConnectionPool, receiver.ID, ReceiverUpdate{ + Email: "updated_email@email.com", + ExternalId: "", + }) + require.NoError(t, err) + + receiver, err = receiverModel.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.NotEqual(t, email, *receiver.Email) + assert.Equal(t, "updated_email@email.com", *receiver.Email) + assert.Equal(t, externalID, receiver.ExternalID) + }) + + t.Run("updates external ID successfully", func(t *testing.T) { + resetReceiver(t, ctx, dbConnectionPool, receiver.ID) + + receiver, err = receiverModel.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.Equal(t, email, *receiver.Email) + assert.Equal(t, externalID, receiver.ExternalID) + + err := receiverModel.Update(ctx, dbConnectionPool, receiver.ID, ReceiverUpdate{ + Email: "updated_email@email.com", + ExternalId: "newExternalID", + }) + require.NoError(t, err) + + receiver, err = receiverModel.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.NotEqual(t, email, *receiver.Email) + assert.Equal(t, "updated_email@email.com", *receiver.Email) + assert.NotEqual(t, externalID, receiver.ExternalID) + assert.Equal(t, "newExternalID", receiver.ExternalID) + }) +} diff --git a/internal/data/receivers_wallet.go b/internal/data/receivers_wallet.go new file mode 100644 index 000000000..7d37cd301 --- /dev/null +++ b/internal/data/receivers_wallet.go @@ -0,0 +1,464 @@ +package data + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/stellar/go/network" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + + "github.com/lib/pq" +) + +const OTPExpirationTimeMinutes = 30 + +type ReceiversWalletStatusHistoryEntry struct { + Status ReceiversWalletStatus `json:"status"` + Timestamp time.Time `json:"timestamp"` +} + +type ReceiverWallet struct { + ID string `json:"id" db:"id"` + Receiver Receiver `json:"receiver" db:"receiver"` + Wallet Wallet `json:"wallet" db:"wallet"` + StellarAddress string `json:"stellar_address,omitempty" db:"stellar_address"` + StellarMemo string `json:"stellar_memo,omitempty" db:"stellar_memo"` + StellarMemoType string `json:"stellar_memo_type,omitempty" db:"stellar_memo_type"` + Status ReceiversWalletStatus `json:"status" db:"status"` + StatusHistory []ReceiversWalletStatusHistoryEntry `json:"status_history,omitempty" db:"status_history"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + OTP string `json:"-" db:"otp"` + OTPCreatedAt *time.Time `json:"-" db:"otp_created_at"` + OTPConfirmedAt *time.Time `json:"otp_confirmed_at,omitempty" db:"otp_confirmed_at"` + InvitedAt *time.Time `json:"invited_at,omitempty" db:"invited_at"` + LastSmsSent *time.Time `json:"last_sms_sent,omitempty" db:"last_sms_sent"` + ReceiverWalletStats +} + +type ReceiverWalletStats struct { + TotalPayments string `json:"total_payments,omitempty" db:"total_payments"` + PaymentsReceived string `json:"payments_received,omitempty" db:"payments_received"` + FailedPayments string `json:"failed_payments,omitempty" db:"failed_payments"` + RemainingPayments string `json:"remaining_payments,omitempty" db:"remaining_payments"` + ReceivedAmounts ReceivedAmounts `json:"received_amounts,omitempty" db:"received_amounts"` +} + +type ReceiverWalletModel struct { + dbConnectionPool db.DBConnectionPool +} + +type ReceiverWalletInsert struct { + ReceiverID string + WalletID string +} + +func (rw *ReceiverWalletModel) GetWithReceiverIds(ctx context.Context, sqlExec db.SQLExecuter, receiverIds ReceiverIDs) ([]ReceiverWallet, error) { + receiverWallets := []ReceiverWallet{} + query := ` + WITH receiver_wallets_cte AS ( + SELECT + rw.id, + rw.receiver_id, + rw.stellar_address, + rw.stellar_memo, + rw.stellar_memo_type, + rw.status, + rw.created_at, + rw.updated_at, + w.id as wallet_id, + w.name as wallet_name, + w.homepage as wallet_homepage, + w.sep_10_client_domain as wallet_sep_10_client_domain + FROM receiver_wallets rw + JOIN wallets w ON rw.wallet_id = w.id + WHERE rw.receiver_id = ANY($1::varchar[]) + ), receiver_wallets_stats AS ( + SELECT + rwc.id as receiver_wallet_id, + COUNT(p) as total_payments, + COUNT(p) FILTER(WHERE p.status = 'SUCCESS') as payments_received, + COUNT(p) FILTER(WHERE p.status = 'FAILED') as failed_payments, + COUNT(p) FILTER(WHERE p.status IN ('DRAFT', 'READY', 'PENDING', 'PAUSED')) as remaining_payments, + a.code as asset_code, + a.issuer as asset_issuer, + COALESCE(SUM(p.amount) FILTER(WHERE p.asset_id = a.id AND p.status = 'SUCCESS'), '0') as received_amount + FROM receiver_wallets_cte rwc + JOIN payments p ON rwc.receiver_id = p.receiver_id + JOIN disbursements d ON p.disbursement_id = d.id AND rwc.wallet_id = d.wallet_id + JOIN assets a ON a.id = p.asset_id + GROUP BY (rwc.id, a.code, a.issuer) + ), receiver_wallets_stats_aggregate AS ( + SELECT + rws.receiver_wallet_id as receiver_wallet_id, + SUM(rws.total_payments) as total_payments, + SUM(rws.payments_received) as payments_received, + SUM(rws.failed_payments) as failed_payments, + SUM(rws.remaining_payments) as remaining_payments, + jsonb_agg(jsonb_build_object('asset_code', rws.asset_code, 'asset_issuer', rws.asset_issuer, 'received_amount', rws.received_amount::text)) as received_amounts + FROM receiver_wallets_stats rws + GROUP BY (rws.receiver_wallet_id) + ), receiver_wallets_messages AS ( + SELECT + rwc.id as receiver_wallet_id, + MIN(m.created_at) as invited_at, + MAX(m.created_at) as last_sms_sent + FROM receiver_wallets_cte rwc + LEFT JOIN messages m ON rwc.id = m.receiver_wallet_id + WHERE m.status = 'SUCCESS' + GROUP BY (rwc.id) + ) + SELECT + rwc.id, + rwc.receiver_id as "receiver.id", + COALESCE(rwc.stellar_address, '') as stellar_address, + COALESCE(rwc.stellar_memo, '') as stellar_memo, + COALESCE(rwc.stellar_memo_type, '') as stellar_memo_type, + rwc.status, + rwc.created_at, + rwc.updated_at, + rwc.wallet_id as "wallet.id", + rwc.wallet_name as "wallet.name", + rwc.wallet_homepage as "wallet.homepage", + rwc.wallet_sep_10_client_domain as "wallet.sep_10_client_domain", + COALESCE(rws.total_payments, '0') as total_payments, + COALESCE(rws.payments_received, '0') as payments_received, + COALESCE(rws.failed_payments, '0') as failed_payments, + COALESCE(rws.remaining_payments, '0') as remaining_payments, + rws.received_amounts, + rwm.invited_at as invited_at, + rwm.last_sms_sent as last_sms_sent + FROM receiver_wallets_cte rwc + LEFT JOIN receiver_wallets_stats_aggregate rws ON rws.receiver_wallet_id = rwc.id + LEFT JOIN receiver_wallets_messages rwm ON rwm.receiver_wallet_id = rwc.id + ORDER BY rwc.created_at + ` + + err := sqlExec.SelectContext(ctx, &receiverWallets, query, pq.StringArray(receiverIds)) + if err != nil { + return nil, fmt.Errorf("error querying receivers wallets: %w", err) + } + + return receiverWallets, nil +} + +// GetByReceiverIDsAndWalletID returns a list of receiver wallets by receiver IDs and wallet ID. +func (rw *ReceiverWalletModel) GetByReceiverIDsAndWalletID(ctx context.Context, sqlExec db.SQLExecuter, receiverIds []string, walletId string) ([]*ReceiverWallet, error) { + receiverWallets := []*ReceiverWallet{} + query := ` + SELECT + rw.id, + rw.receiver_id as "receiver.id", + rw.wallet_id as "wallet.id", + rw.status + FROM receiver_wallets rw + WHERE rw.receiver_id = ANY($1) + AND rw.wallet_id = $2 + ` + err := sqlExec.SelectContext(ctx, &receiverWallets, query, pq.Array(receiverIds), walletId) + if err != nil { + return nil, fmt.Errorf("error querying receiver wallets: %w", err) + } + + return receiverWallets, nil +} + +func (rw *ReceiverWalletModel) GetAllPendingRegistration(ctx context.Context, daysSinceLastInvitationMessageSent, maxTries int) ([]*ReceiverWallet, error) { + const query = ` + WITH receiver_wallet_ids_invitation_message_sent_between_period AS ( + SELECT + m.receiver_wallet_id + FROM + messages m + WHERE + m.created_at >= $1 + GROUP BY + m.receiver_wallet_id + ), receiver_wallet_ids_reached_invitation_message_max_tries AS ( + SELECT + m.receiver_wallet_id + FROM + messages m + GROUP BY + m.receiver_wallet_id + HAVING + COUNT(*) >= $2 + ) + SELECT + rw.id, + r.id AS "receiver.id", + r.phone_number AS "receiver.phone_number", + r.email AS "receiver.email", + w.id AS "wallet.id", + w.name AS "wallet.name" + FROM + receiver_wallets rw + INNER JOIN receivers r ON r.id = rw.receiver_id + INNER JOIN wallets w ON w.id = rw.wallet_id + WHERE + rw.status = 'READY' + AND rw.id NOT IN ( + SELECT receiver_wallet_id FROM receiver_wallet_ids_invitation_message_sent_between_period + UNION + SELECT receiver_wallet_id FROM receiver_wallet_ids_reached_invitation_message_max_tries + ) + ` + + var ( + receiverWallets []*ReceiverWallet + err error + ) + + interval := time.Now().AddDate(0, 0, -daysSinceLastInvitationMessageSent).UTC() + err = rw.dbConnectionPool.SelectContext(ctx, &receiverWallets, query, interval, maxTries) + + if err != nil { + return nil, fmt.Errorf("error querying pending registration receiver wallets: %w", err) + } + + return receiverWallets, nil +} + +// UpdateOTPByReceiverPhoneNumberAndWalletDomain updates receiver wallet OTP if its not verified yet, +// and returns the number of updated rows. +func (rw *ReceiverWalletModel) UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx context.Context, receiverPhoneNumber, sep10ClientDomain, otp string) (numberOfUpdatedRows int, err error) { + query := ` + WITH rw_cte AS ( + SELECT + rw.id, + rw.otp_confirmed_at + FROM receiver_wallets rw + INNER JOIN receivers r ON rw.receiver_id = r.id + INNER JOIN wallets w ON rw.wallet_id = w.id + WHERE r.phone_number = $1 + AND w.sep_10_client_domain = $2 + AND rw.otp_confirmed_at IS NULL + ) + UPDATE + receiver_wallets + SET + otp = $3, + otp_created_at = NOW() + FROM rw_cte + WHERE + receiver_wallets.id = rw_cte.id + ` + + rows, err := rw.dbConnectionPool.ExecContext(ctx, query, receiverPhoneNumber, sep10ClientDomain, otp) + if err != nil { + return 0, fmt.Errorf("error updating receiver wallets otp: %w", err) + } + + updatedRowsAffected, err := rows.RowsAffected() + if err != nil { + return 0, fmt.Errorf("error getting updated rows of receiver wallets otp: %w", err) + } + + return int(updatedRowsAffected), nil +} + +// Insert inserts a new receiver wallet into the database. +func (rw *ReceiverWalletModel) Insert(ctx context.Context, sqlExec db.SQLExecuter, insert ReceiverWalletInsert) (string, error) { + var newId string + query := ` + INSERT INTO receiver_wallets (receiver_id, wallet_id) + VALUES ($1, $2) + RETURNING id + ` + + err := sqlExec.GetContext(ctx, &newId, query, insert.ReceiverID, insert.WalletID) + if err != nil { + return "", fmt.Errorf("error inserting receiver wallet: %w", err) + } + return newId, nil +} + +// GetByReceiverIDAndWalletDomain returns a receiver wallet that match the receiver ID and wallet domain. +func (rw *ReceiverWalletModel) GetByReceiverIDAndWalletDomain(ctx context.Context, receiverId string, walletDomain string, sqlExec db.SQLExecuter) (*ReceiverWallet, error) { + var receiverWallet ReceiverWallet + query := ` + SELECT + rw.id, + rw.receiver_id as "receiver.id", + rw.status, + COALESCE(rw.otp, '') as otp, + rw.otp_created_at, + w.id as "wallet.id", + w.name as "wallet.name", + w.sep_10_client_domain as "wallet.sep_10_client_domain" + FROM + receiver_wallets rw + JOIN + wallets w ON rw.wallet_id = w.id + WHERE + rw.receiver_id = $1 + AND + w.sep_10_client_domain = $2 + ` + + err := sqlExec.GetContext(ctx, &receiverWallet, query, receiverId, walletDomain) + if err != nil { + return nil, fmt.Errorf("error querying receiver wallet: %w", err) + } + + return &receiverWallet, nil +} + +// UpdateReceiverWallet updates informations from the receiver wallet. +func (rw *ReceiverWalletModel) UpdateReceiverWallet(ctx context.Context, receiverWallet ReceiverWallet, sqlExec db.SQLExecuter) error { + query := ` + UPDATE + receiver_wallets rw + SET + status = $1, + stellar_address = $2, + stellar_memo = $3, + stellar_memo_type = $4, + otp_confirmed_at = $5 + WHERE rw.id = $6 + ` + + _, err := sqlExec.ExecContext(ctx, query, + receiverWallet.Status, + receiverWallet.StellarAddress, + sql.NullString{String: receiverWallet.StellarMemo, Valid: receiverWallet.StellarMemo != ""}, + sql.NullString{String: receiverWallet.StellarMemoType, Valid: receiverWallet.StellarMemoType != ""}, + time.Now(), + receiverWallet.ID) + if err != nil { + return fmt.Errorf("error updating receiver wallet: %w", err) + } + + return nil +} + +// VerifyReceiverWalletOTP validates the receiver wallet OTP. +func (rw *ReceiverWalletModel) VerifyReceiverWalletOTP(ctx context.Context, networkPassphrase string, receiverWallet ReceiverWallet, otp string) error { + if networkPassphrase == network.TestNetworkPassphrase { + if otp == TestnetAlwaysValidOTP { + log.Ctx(ctx).Warnf("OTP is being approved because TestnetAlwaysValidOTP (%s) was used", TestnetAlwaysValidOTP) + return nil + } else if otp == TestnetAlwaysInvalidOTP { + log.Ctx(ctx).Errorf("OTP is being denied because TestnetAlwaysInvalidOTP (%s) was used", TestnetAlwaysInvalidOTP) + return fmt.Errorf("otp does not match with value saved in the database") + } + } + + if receiverWallet.OTP != otp { + return fmt.Errorf("otp does not match with value saved in the database") + } + + if receiverWallet.OTPCreatedAt.IsZero() { + return fmt.Errorf("otp does not have a valid created_at time") + } + + // TODO: use the commented out version deppending on the conclusion from https://stellarfoundation.slack.com/archives/C04C9MLM9UZ/p1686692315222719 + otpExpirationTime := receiverWallet.OTPCreatedAt.Add(time.Minute * OTPExpirationTimeMinutes) + if otpExpirationTime.Before(time.Now()) { + return fmt.Errorf("otp is expired") + } + return nil +} + +// UpdateStatusByDisbursementID updates the status of the receiver wallets associated with a disbursement. +func (rw *ReceiverWalletModel) UpdateStatusByDisbursementID(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string, from, to ReceiversWalletStatus) error { + if err := from.TransitionTo(to); err != nil { + return fmt.Errorf("cannot transition from %s to %s for receiver wallets for disbursement %s: %w", from, to, disbursementID, err) + } + query := ` + UPDATE receiver_wallets + SET status = $1, + status_history = array_append(status_history, create_receiver_wallet_status_history(NOW(), $1)) + WHERE id IN ( + SELECT rw.id + FROM payments p + JOIN receiver_wallets rw on p.receiver_wallet_id = rw.id + WHERE p.disbursement_id = $2 + AND rw.status = $3 + ) + ` + + result, err := sqlExec.ExecContext(ctx, query, to, disbursementID, from) + if err != nil { + return fmt.Errorf("error updating receiver_wallets for disbursement %s: %w", disbursementID, err) + } + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + log.Ctx(ctx).Infof("Set %d receiver_wallet from %s to %s for disbursement %s", numRowsAffected, from, to, disbursementID) + return nil +} + +func (rw *ReceiverWallet) statusHistoryFromByteArray(statusHistoryJSON pq.ByteaArray) error { + for _, sh := range statusHistoryJSON { + var shEntry ReceiversWalletStatusHistoryEntry + err := json.Unmarshal(sh, &shEntry) + if err != nil { + return fmt.Errorf("error unmarshaling status_history column: %w", err) + } + rw.StatusHistory = append(rw.StatusHistory, shEntry) + } + return nil +} + +func (rw *ReceiverWallet) statusHistoryJson() ([]string, error) { + var statusHistoryJSON []string + for _, sh := range rw.StatusHistory { + shJSONBytes, err := json.Marshal(sh) + if err != nil { + return nil, fmt.Errorf("error converting status history to json for receiver wallet %s: %w", rw.ID, err) + } + statusHistoryJSON = append(statusHistoryJSON, string(shJSONBytes)) + } + return statusHistoryJSON, nil +} + +// GetByStellarAccountAndMemo returns a receiver wallets that match the Stellar Account. +func (rw *ReceiverWalletModel) GetByStellarAccountAndMemo(ctx context.Context, stellarAccount, stellarMemo string) (*ReceiverWallet, error) { + // build query + var receiverWallets ReceiverWallet + query := ` + SELECT + rw.id, + rw.receiver_id as "receiver.id", + rw.status, + COALESCE(rw.stellar_address, '') as stellar_address, + COALESCE(rw.stellar_memo, '') as stellar_memo, + COALESCE(rw.stellar_memo_type, '') as stellar_memo_type, + COALESCE(rw.otp, '') as otp, + rw.otp_created_at, + w.id as "wallet.id", + w.name as "wallet.name", + w.homepage as "wallet.homepage" + FROM receiver_wallets rw + JOIN wallets w ON rw.wallet_id = w.id + WHERE rw.stellar_address = $1 + ` + + // append memo to query if it is not empty + args := []interface{}{stellarAccount} + if stellarMemo != "" { + query += " AND rw.stellar_memo = $2" + args = append(args, stellarMemo) + } else { + query += " AND (rw.stellar_memo IS NULL OR rw.stellar_memo = '')" + } + + // execute query + err := rw.dbConnectionPool.GetContext(ctx, &receiverWallets, query, args...) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("no receiver wallet could be found in GetByStellarAccountAndMemo: %w", ErrRecordNotFound) + } + return nil, fmt.Errorf("error querying receiver wallet: %w", err) + } + + return &receiverWallets, nil +} diff --git a/internal/data/receivers_wallet_test.go b/internal/data/receivers_wallet_test.go new file mode 100644 index 000000000..182a35af9 --- /dev/null +++ b/internal/data/receivers_wallet_test.go @@ -0,0 +1,1102 @@ +package data + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/lib/pq" + "github.com/stellar/go/network" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReceiversWalletModelGetWithReceiverId(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + + receiverWalletModel := ReceiverWalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns empty array when receiver does not exist", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, errReceiver := receiverWalletModel.GetWithReceiverIds(ctx, dbTx, ReceiverIDs{"invalid_id"}) + require.NoError(t, errReceiver) + require.Empty(t, actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns empty array when receiver does not have a receiver_wallet", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, errReceiver := receiverWalletModel.GetWithReceiverIds(ctx, dbTx, ReceiverIDs{receiver.ID}) + require.NoError(t, errReceiver) + require.Empty(t, actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + asset := CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiverWallet1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, DraftReceiversWalletStatus) + + message1 := CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &receiverWallet1.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message2 := CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &receiverWallet1.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + disbursementModel := DisbursementModel{dbConnectionPool: dbConnectionPool} + disbursement := Disbursement{ + Status: DraftDisbursementStatus, + Asset: asset, + Country: country, + } + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + paymentModel := PaymentModel{dbConnectionPool: dbConnectionPool} + payment := Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Asset: *asset, + } + + t.Run("returns receiver_wallet without payments", func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverWalletModel.GetWithReceiverIds(ctx, dbTx, ReceiverIDs{receiver.ID}) + require.NoError(t, err) + expected := []ReceiverWallet{ + { + ID: receiverWallet1.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet1.ID, + Name: wallet1.Name, + Homepage: wallet1.Homepage, + SEP10ClientDomain: wallet1.SEP10ClientDomain, + }, + StellarAddress: receiverWallet1.StellarAddress, + StellarMemo: receiverWallet1.StellarMemo, + StellarMemoType: receiverWallet1.StellarMemoType, + Status: receiverWallet1.Status, + CreatedAt: receiverWallet1.CreatedAt, + UpdatedAt: receiverWallet1.CreatedAt, + InvitedAt: &message1.CreatedAt, + LastSmsSent: &message2.CreatedAt, + ReceiverWalletStats: ReceiverWalletStats{ + TotalPayments: "0", + PaymentsReceived: "0", + FailedPayments: "0", + RemainingPayments: "0", + ReceivedAmounts: nil, + }, + }, + } + assert.Equal(t, expected, actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns receiver_wallet with payments", func(t *testing.T) { + disbursement.Name = "disbursement 1" + disbursement.Wallet = wallet1 + d := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = SuccessPaymentStatus + payment.Disbursement = d + payment.ReceiverWallet = receiverWallet1 + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + disbursement.Name = "disbursement 2" + disbursement.Wallet = wallet1 + d = CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = DraftPaymentStatus + payment.Disbursement = d + payment.ReceiverWallet = receiverWallet1 + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverWalletModel.GetWithReceiverIds(ctx, dbTx, ReceiverIDs{receiver.ID}) + require.NoError(t, err) + expected := []ReceiverWallet{ + { + ID: receiverWallet1.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet1.ID, + Name: wallet1.Name, + Homepage: wallet1.Homepage, + SEP10ClientDomain: wallet1.SEP10ClientDomain, + }, + StellarAddress: receiverWallet1.StellarAddress, + StellarMemo: receiverWallet1.StellarMemo, + StellarMemoType: receiverWallet1.StellarMemoType, + Status: receiverWallet1.Status, + CreatedAt: receiverWallet1.CreatedAt, + UpdatedAt: receiverWallet1.CreatedAt, + InvitedAt: &message1.CreatedAt, + LastSmsSent: &message2.CreatedAt, + ReceiverWalletStats: ReceiverWalletStats{ + TotalPayments: "2", + PaymentsReceived: "1", + FailedPayments: "0", + RemainingPayments: "1", + ReceivedAmounts: []Amount{ + { + AssetCode: "USDC", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "50.0000000", + }, + }, + }, + }, + } + assert.Equal(t, expected, actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + t.Run("returns multiple receiver_wallets", func(t *testing.T) { + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + + disbursement.Name = "disbursement 1" + disbursement.Wallet = wallet1 + d := CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = SuccessPaymentStatus + payment.Disbursement = d + payment.ReceiverWallet = receiverWallet1 + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + disbursement.Name = "disbursement 2" + disbursement.Wallet = wallet1 + d = CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = DraftPaymentStatus + payment.Disbursement = d + payment.ReceiverWallet = receiverWallet1 + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + wallet2 := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet2", "https://www.wallet2.com", "www.wallet2.com", "wallet2://") + receiverWallet2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, DraftReceiversWalletStatus) + + message3 := CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet2.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message4 := CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet2.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + disbursement.Name = "disbursement 3" + disbursement.Wallet = wallet2 + d = CreateDisbursementFixture(t, ctx, dbConnectionPool, &disbursementModel, &disbursement) + + payment.Status = DraftPaymentStatus + payment.Disbursement = d + payment.ReceiverWallet = receiverWallet2 + CreatePaymentFixture(t, ctx, dbConnectionPool, &paymentModel, &payment) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverWalletModel.GetWithReceiverIds(ctx, dbTx, ReceiverIDs{receiver.ID}) + require.NoError(t, err) + expected := []ReceiverWallet{ + { + ID: receiverWallet1.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet1.ID, + Name: wallet1.Name, + Homepage: wallet1.Homepage, + SEP10ClientDomain: wallet1.SEP10ClientDomain, + }, + StellarAddress: receiverWallet1.StellarAddress, + StellarMemo: receiverWallet1.StellarMemo, + StellarMemoType: receiverWallet1.StellarMemoType, + Status: receiverWallet1.Status, + CreatedAt: receiverWallet1.CreatedAt, + UpdatedAt: receiverWallet1.CreatedAt, + InvitedAt: &message1.CreatedAt, + LastSmsSent: &message2.CreatedAt, + ReceiverWalletStats: ReceiverWalletStats{ + TotalPayments: "2", + PaymentsReceived: "1", + FailedPayments: "0", + RemainingPayments: "1", + ReceivedAmounts: []Amount{ + { + AssetCode: "USDC", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "50.0000000", + }, + }, + }, + }, + { + ID: receiverWallet2.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet2.ID, + Name: wallet2.Name, + Homepage: wallet2.Homepage, + SEP10ClientDomain: wallet2.SEP10ClientDomain, + }, + StellarAddress: receiverWallet2.StellarAddress, + StellarMemo: receiverWallet2.StellarMemo, + StellarMemoType: receiverWallet2.StellarMemoType, + Status: receiverWallet2.Status, + CreatedAt: receiverWallet2.CreatedAt, + UpdatedAt: receiverWallet2.CreatedAt, + InvitedAt: &message3.CreatedAt, + LastSmsSent: &message4.CreatedAt, + ReceiverWalletStats: ReceiverWalletStats{ + TotalPayments: "1", + PaymentsReceived: "0", + FailedPayments: "0", + RemainingPayments: "1", + ReceivedAmounts: []Amount{ + { + AssetCode: "USDC", + AssetIssuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + ReceivedAmount: "0", + }, + }, + }, + }, + } + assert.Equal(t, expected, actual) + + err = dbTx.Commit() + require.NoError(t, err) + }) + + DeleteAllMessagesFixtures(t, ctx, dbConnectionPool) + DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + + t.Run("returns receiver_wallet with session", func(t *testing.T) { + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, DraftReceiversWalletStatus) + + message1 = CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &receiverWallet.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message2 = CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &receiverWallet.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + // Initializing a new Tx. + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + actual, err := receiverWalletModel.GetWithReceiverIds(ctx, dbTx, ReceiverIDs{receiver.ID}) + require.NoError(t, err) + expected := []ReceiverWallet{ + { + ID: receiverWallet.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet1.ID, + Name: wallet1.Name, + Homepage: wallet1.Homepage, + SEP10ClientDomain: wallet1.SEP10ClientDomain, + }, + StellarAddress: receiverWallet.StellarAddress, + StellarMemo: receiverWallet.StellarMemo, + StellarMemoType: receiverWallet.StellarMemoType, + Status: receiverWallet.Status, + CreatedAt: receiverWallet.CreatedAt, + UpdatedAt: receiverWallet.CreatedAt, + InvitedAt: &message1.CreatedAt, + LastSmsSent: &message2.CreatedAt, + ReceiverWalletStats: ReceiverWalletStats{ + TotalPayments: "0", + PaymentsReceived: "0", + FailedPayments: "0", + RemainingPayments: "0", + ReceivedAmounts: nil, + }, + }, + } + assert.Equal(t, expected, actual) + + // Commit the transaction. + err = dbTx.Commit() + require.NoError(t, err) + }) +} + +func Test_GetByReceiverIDAndWalletDomain(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverWalletModel := ReceiverWalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when receiver wallet does not exist", func(t *testing.T) { + actual, errGetReceiverWallet := receiverWalletModel.GetByReceiverIDAndWalletDomain(ctx, "invalid_id", "invalid_domain", dbConnectionPool) + require.Error(t, errGetReceiverWallet, "error querying receiver wallet: sql: no rows in result set") + require.Empty(t, actual) + }) + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + // TODO update CreateReceiverWalletFixture to allow create a wallet with a ReceiverWallet object + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + query := ` + UPDATE + receiver_wallets rw + SET + otp = $1, + otp_created_at = NOW() + WHERE + rw.id = $2 + RETURNING + otp_created_at + ` + err = dbConnectionPool.GetContext(ctx, &receiverWallet.OTPCreatedAt, query, "123456", receiverWallet.ID) + require.NoError(t, err) + + t.Run("returns error when receiver wallet not found for receiver id", func(t *testing.T) { + actual, errGetReceiverWallet := receiverWalletModel.GetByReceiverIDAndWalletDomain(ctx, "invalid_id", wallet.SEP10ClientDomain, dbConnectionPool) + require.Error(t, errGetReceiverWallet, "error querying receiver wallet: sql: no rows in result set") + require.Empty(t, actual) + }) + + t.Run("returns error when receiver wallet not found with wallet domain", func(t *testing.T) { + actual, errGetReceiverWallet := receiverWalletModel.GetByReceiverIDAndWalletDomain(ctx, receiver.ID, "invalid_domain", dbConnectionPool) + require.Error(t, errGetReceiverWallet, "error querying receiver wallet: sql: no rows in result set") + require.Empty(t, actual) + }) + + t.Run("returns receiver_wallet", func(t *testing.T) { + actual, err := receiverWalletModel.GetByReceiverIDAndWalletDomain(ctx, receiver.ID, wallet.SEP10ClientDomain, dbConnectionPool) + require.NoError(t, err) + + expected := ReceiverWallet{ + ID: receiverWallet.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet.ID, + Name: wallet.Name, + SEP10ClientDomain: wallet.SEP10ClientDomain, + }, + Status: receiverWallet.Status, + OTP: "123456", + OTPCreatedAt: receiverWallet.OTPCreatedAt, + } + + assert.Equal(t, expected, *actual) + }) +} + +func Test_UpdateReceiverWallet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverWalletModel := ReceiverWalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when receiver wallet does not exist", func(t *testing.T) { + err := receiverWalletModel.UpdateReceiverWallet(ctx, ReceiverWallet{ID: "invalid_id", Status: DraftReceiversWalletStatus}, dbConnectionPool) + require.NoError(t, err) + }) + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://") + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + + t.Run("returns error when status is not valid", func(t *testing.T) { + receiverWallet.Status = "invalid_status" + err := receiverWalletModel.UpdateReceiverWallet(ctx, *receiverWallet, dbConnectionPool) + require.Error(t, err, "error querying receiver wallet: sql: no rows in result set") + }) + + t.Run("successfuly update receiver wallet", func(t *testing.T) { + receiverWallet.StellarAddress = "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444" + receiverWallet.StellarMemo = "123456" + receiverWallet.StellarMemoType = "id" + receiverWallet.Status = RegisteredReceiversWalletStatus + + err := receiverWalletModel.UpdateReceiverWallet(ctx, *receiverWallet, dbConnectionPool) + require.NoError(t, err) + + // validate if the receiver wallet has been updated + query := ` + SELECT + rw.status, + rw.stellar_address, + rw.stellar_memo, + rw.stellar_memo_type, + otp_confirmed_at + FROM + receiver_wallets rw + WHERE + rw.id = $1 + ` + receiverWalletUpdated := ReceiverWallet{} + err = dbConnectionPool.GetContext(ctx, &receiverWalletUpdated, query, receiverWallet.ID) + require.NoError(t, err) + + assert.Equal(t, RegisteredReceiversWalletStatus, receiverWalletUpdated.Status) + assert.Equal(t, "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", receiverWalletUpdated.StellarAddress) + assert.Equal(t, "123456", receiverWalletUpdated.StellarMemo) + assert.Equal(t, "id", receiverWalletUpdated.StellarMemoType) + require.NotEmpty(t, receiverWalletUpdated.OTPConfirmedAt) + }) +} + +func Test_ReceiverWallet_UpdateOTPByReceiverPhoneNumberAndWalletHomePage(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverWalletModel := ReceiverWalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns 1 updated row when the receiver wallet has not confirmed yet", func(t *testing.T) { + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "http://home.page", "home.page", "wallet1://") + _ = CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, RegisteredReceiversWalletStatus) + + testingOTP := "123456" + + rowsUpdated, err := receiverWalletModel.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, receiver1.PhoneNumber, wallet1.SEP10ClientDomain, testingOTP) + require.NoError(t, err) + assert.Equal(t, 1, rowsUpdated) + }) + + t.Run("returns 1 updated row when trying to renew an OTP with an unconfirmed receiver wallet", func(t *testing.T) { + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + receiver2 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "testWalletC", "http://home3.page", "home3.page", "wallet3://") + + rw1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, RegisteredReceiversWalletStatus) + rw2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet1.ID, RegisteredReceiversWalletStatus) + + testingOTP := "222333" + + q := ` + UPDATE + receiver_wallets + SET + otp_confirmed_at = NOW() + WHERE + id = $1 + ` + _, err := dbConnectionPool.ExecContext(ctx, q, rw1.ID) + require.NoError(t, err) + + rowsUpdated, err := receiverWalletModel.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, receiver2.PhoneNumber, wallet1.SEP10ClientDomain, testingOTP) + require.NoError(t, err) + assert.Equal(t, 1, rowsUpdated) + + q = `SELECT otp FROM receiver_wallets WHERE id = $1` + var dbOTP string + err = dbConnectionPool.QueryRowxContext(ctx, q, rw2.ID).Scan(&dbOTP) + require.NoError(t, err) + assert.Equal(t, testingOTP, dbOTP) + }) + + t.Run("returns 0 updated rows when when the receiver wallet is confirmed", func(t *testing.T) { + receiver1 := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "testWalletD", "http://home4.page", "home4.page", "wallet4://") + _ = CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, RegisteredReceiversWalletStatus) + + testingOTP := "123456" + + q := ` + UPDATE + receiver_wallets + SET + otp_confirmed_at = NOW() + ` + _, err := dbConnectionPool.ExecContext(ctx, q) + require.NoError(t, err) + + rowsUpdated, err := receiverWalletModel.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, receiver1.PhoneNumber, wallet1.SEP10ClientDomain, testingOTP) + require.NoError(t, err) + assert.Equal(t, 0, rowsUpdated) + }) +} + +func Test_VerifyReceiverWalletOTP(t *testing.T) { + ctx := context.Background() + receiverWalletModel := ReceiverWalletModel{} + + expiredOTPCreatedAt := time.Now().Add(-OTPExpirationTimeMinutes * time.Minute).Add(-time.Second) // expired 1 second ago + validOTPTime := time.Now() + + testCases := []struct { + name string + networkPassphrase string + attemptedOTP string + otp string + otpCreatedAt time.Time + wantErr error + }{ + // mismatching OTP fails: + { + name: "mismatching OTP fails", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: "123123", + otp: "123456", + otpCreatedAt: time.Time{}, // invalid created_at + wantErr: fmt.Errorf("otp does not match with value saved in the database"), + }, + { + name: "mismatching OTP fails when passing the TestnetAlwaysValidOTP in Pubnet", + networkPassphrase: network.PublicNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: "123456", + otpCreatedAt: time.Time{}, // invalid created_at + wantErr: fmt.Errorf("otp does not match with value saved in the database"), + }, + { + name: "mismatching OTP succeeds when passing the TestnetAlwaysValidOTP in Testnet", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: "123456", + otpCreatedAt: time.Time{}, // invalid created_at + wantErr: nil, + }, + + // matching OTP fails when its created_at date is invalid: + { + name: "matching OTP fails when its created_at date is invalid", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: "123456", + otp: "123456", + otpCreatedAt: time.Time{}, // invalid created_at + wantErr: fmt.Errorf("otp does not have a valid created_at time"), + }, + { + name: "matching OTP fails when its created_at date is invalid and we pass TestnetAlwaysValidOTP in Pubnet", + networkPassphrase: network.PublicNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: TestnetAlwaysValidOTP, + otpCreatedAt: time.Time{}, // invalid created_at + wantErr: fmt.Errorf("otp does not have a valid created_at time"), + }, + { + name: "matching OTP succeeds when its created_at date is invalid but we pass TestnetAlwaysValidOTP in Testnet", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: "123456", + otpCreatedAt: time.Time{}, // invalid created_at + wantErr: nil, + }, + + // returns error when otp is expired: + { + name: "matching OTP fails when OTP is expired", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: "123456", + otp: "123456", + otpCreatedAt: expiredOTPCreatedAt, + wantErr: fmt.Errorf("otp is expired"), + }, + { + name: "matching OTP fails when OTP is expired and we pass TestnetAlwaysValidOTP in Pubnet", + networkPassphrase: network.PublicNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: TestnetAlwaysValidOTP, + otpCreatedAt: expiredOTPCreatedAt, + wantErr: fmt.Errorf("otp is expired"), + }, + { + name: "matching OTP fails when OTP is expired but we pass TestnetAlwaysValidOTP in Testnet", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: "123456", + otpCreatedAt: expiredOTPCreatedAt, + wantErr: nil, + }, + + // OTP is valid πŸŽ‰ + { + name: "OTP is valid πŸŽ‰", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: "123456", + otp: "123456", + otpCreatedAt: validOTPTime, + wantErr: nil, + }, + { + name: "OTP is valid πŸŽ‰ also when we pass TestnetAlwaysValidOTP in Pubnet", + networkPassphrase: network.PublicNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: TestnetAlwaysValidOTP, + otpCreatedAt: validOTPTime, + wantErr: nil, + }, + { + name: "OTP is valid πŸŽ‰ also when we pass TestnetAlwaysValidOTP in Testnet", + networkPassphrase: network.TestNetworkPassphrase, + attemptedOTP: TestnetAlwaysValidOTP, + otp: TestnetAlwaysValidOTP, + otpCreatedAt: validOTPTime, + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + receiverWallet := ReceiverWallet{ + OTP: tc.otp, + OTPCreatedAt: &tc.otpCreatedAt, + } + err := receiverWalletModel.VerifyReceiverWalletOTP(ctx, tc.networkPassphrase, receiverWallet, tc.attemptedOTP) + if tc.wantErr != nil { + assert.Equal(t, tc.wantErr, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_ReceiverWallet_statusHistoryFromByteArray(t *testing.T) { + var receiverWallet ReceiverWallet + + t.Run("returns error when status history is invalid", func(t *testing.T) { + err := receiverWallet.statusHistoryFromByteArray(pq.ByteaArray{[]byte("invalid")}) + require.Error(t, err, "error unmarshaling status_history column:") + }) + + t.Run("returns status history successfully", func(t *testing.T) { + statusHistory := pq.ByteaArray{[]byte(`{"status": "DRAFT", "timestamp": "2023-03-11T01:20:39.363154Z"}`)} + expected := []ReceiversWalletStatusHistoryEntry{ + { + Status: DraftReceiversWalletStatus, + Timestamp: time.Date(2023, 0o3, 11, 0o1, 20, 39, 363154000, time.UTC), + }, + } + err := receiverWallet.statusHistoryFromByteArray(statusHistory) + require.NoError(t, err) + assert.Equal(t, expected, receiverWallet.StatusHistory) + }) +} + +func Test_ReceiverWallet_statusHistoryJson(t *testing.T) { + entry1 := ReceiversWalletStatusHistoryEntry{ + Status: "READY", + Timestamp: time.Now(), + } + entry2 := ReceiversWalletStatusHistoryEntry{ + Status: "REGISTERED", + Timestamp: time.Now().Add(1 * time.Hour), + } + + receiverWallet := &ReceiverWallet{ + StatusHistory: []ReceiversWalletStatusHistoryEntry{entry1, entry2}, + } + + t.Run("returns status history successfully", func(t *testing.T) { + statusHistoryJSON, err := receiverWallet.statusHistoryJson() + require.NoError(t, err) + + expectedJSON1 := `{"status":"READY","timestamp":"` + entry1.Timestamp.Format(time.RFC3339Nano) + `"}` + expectedJSON2 := `{"status":"REGISTERED","timestamp":"` + entry2.Timestamp.Format(time.RFC3339Nano) + `"}` + + assert.Equal(t, 2, len(receiverWallet.StatusHistory)) + assert.Contains(t, statusHistoryJSON, expectedJSON1) + assert.Contains(t, statusHistoryJSON, expectedJSON2) + }) +} + +func Test_ReceiverWallet_GetAllPendingRegistration(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet1 := CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet1", "https://wallet1.com", "www.wallet.com", "wallet1://") + wallet2 := CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet2", "https://wallet2.com", "www.wallet2.com", "wallet2://") + + rwm := ReceiverWalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("gets all receiver wallets pending registration when no message were sent", func(t *testing.T) { + DeleteAllMessagesFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + + _ = CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, DraftReceiversWalletStatus) + rw2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, ReadyReceiversWalletStatus) + + rws, err := rwm.GetAllPendingRegistration(ctx, 7, 3) + require.NoError(t, err) + + expectedRWs := []*ReceiverWallet{ + { + ID: rw2.ID, + Receiver: Receiver{ + ID: receiver.ID, + PhoneNumber: receiver.PhoneNumber, + Email: receiver.Email, + }, + Wallet: Wallet{ + ID: wallet2.ID, + Name: wallet2.Name, + }, + }, + } + + assert.Len(t, rws, 1) + assert.Equal(t, rws, expectedRWs) + }) + + t.Run("gets all receiver wallets pending registration when days since last invitation is satisfied", func(t *testing.T) { + DeleteAllMessagesFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + + rw1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, DraftReceiversWalletStatus) + rw2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, ReadyReceiversWalletStatus) + + CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &rw1.ID, + Status: PendingMessageStatus, + CreatedAt: time.Now().AddDate(0, 0, -3).UTC(), + UpdatedAt: time.Now().UTC(), + }) + + CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet2.ID, + ReceiverWalletID: &rw2.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Now().AddDate(0, 0, -8).UTC(), + UpdatedAt: time.Now().AddDate(0, 0, -8).UTC(), + }) + + expectedRWs := []*ReceiverWallet{ + { + ID: rw2.ID, + Receiver: Receiver{ + ID: receiver.ID, + PhoneNumber: receiver.PhoneNumber, + Email: receiver.Email, + }, + Wallet: Wallet{ + ID: wallet2.ID, + Name: wallet2.Name, + }, + }, + } + + rws, err := rwm.GetAllPendingRegistration(ctx, 6, 3) + require.NoError(t, err) + + assert.Len(t, rws, 1) + assert.Equal(t, expectedRWs, rws) + }) + + t.Run("get all receiver wallets pending registration when max tries isn't reached", func(t *testing.T) { + DeleteAllMessagesFixtures(t, ctx, dbConnectionPool) + DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + + rw1 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, ReadyReceiversWalletStatus) + rw2 := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, ReadyReceiversWalletStatus) + + // Invitations sent for rw1 - reached max tries + CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &rw1.ID, + Status: FailureMessageStatus, + CreatedAt: time.Now().AddDate(0, 0, -3).UTC(), + UpdatedAt: time.Now().AddDate(0, 0, -3).UTC(), + }) + + CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &rw1.ID, + Status: PendingMessageStatus, + CreatedAt: time.Now().AddDate(0, 0, -6).UTC(), + UpdatedAt: time.Now().AddDate(0, 0, -6).UTC(), + }) + + CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &rw1.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Now().AddDate(0, 0, -9).UTC(), + UpdatedAt: time.Now().AddDate(0, 0, -9).UTC(), + }) + + // Invitations sent for rw2 + CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet2.ID, + ReceiverWalletID: &rw2.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Now().AddDate(0, 0, -5).UTC(), + UpdatedAt: time.Now().AddDate(0, 0, -5).UTC(), + }) + + CreateMessageFixture(t, ctx, dbConnectionPool, &Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet2.ID, + ReceiverWalletID: &rw2.ID, + Status: SuccessMessageStatus, + CreatedAt: time.Now().AddDate(0, 0, -8).UTC(), + UpdatedAt: time.Now().AddDate(0, 0, -8).UTC(), + }) + + expectedRWs := []*ReceiverWallet{ + { + ID: rw2.ID, + Receiver: Receiver{ + ID: receiver.ID, + PhoneNumber: receiver.PhoneNumber, + Email: receiver.Email, + }, + Wallet: Wallet{ + ID: wallet2.ID, + Name: wallet2.Name, + }, + }, + } + + rws, err := rwm.GetAllPendingRegistration(ctx, 3, 3) + require.NoError(t, err) + + assert.Equal(t, expectedRWs, rws) + }) +} + +func Test_GetByStellarAccountAndMemo(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + receiverWalletModel := ReceiverWalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when receiver wallet does not exist", func(t *testing.T) { + actual, innerErr := receiverWalletModel.GetByStellarAccountAndMemo(ctx, "GCRSI42IC7WSW6N46LWPAHQWFI6MLGPBN3BYQ2WMNJ43GNRTIEYCAD6O", "") + require.ErrorIs(t, innerErr, ErrRecordNotFound) + require.Empty(t, actual) + }) + + receiver := CreateReceiverFixture(t, ctx, dbConnectionPool, &Receiver{}) + wallet := CreateWalletFixture(t, ctx, dbConnectionPool, "wallet", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiverWallet := CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, DraftReceiversWalletStatus) + results, err := receiverWalletModel.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, receiver.PhoneNumber, wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + require.Equal(t, 1, results) + + t.Run("wont find the result if stellar address is provided but memo is not", func(t *testing.T) { + actual, innerErr := receiverWalletModel.GetByStellarAccountAndMemo(ctx, receiverWallet.StellarAddress, "") + require.ErrorIs(t, innerErr, ErrRecordNotFound) + require.Empty(t, actual) + }) + + t.Run("wont find the result if memo is provided but stellar address is not", func(t *testing.T) { + actual, innerErr := receiverWalletModel.GetByStellarAccountAndMemo(ctx, "", receiverWallet.StellarMemo) + require.ErrorIs(t, innerErr, ErrRecordNotFound) + require.Empty(t, actual) + }) + + t.Run("returns receiver_wallet when both stellar account and memo are provided", func(t *testing.T) { + actual, innerErr := receiverWalletModel.GetByStellarAccountAndMemo(ctx, receiverWallet.StellarAddress, receiverWallet.StellarMemo) + require.NoError(t, innerErr) + + expected := ReceiverWallet{ + ID: receiverWallet.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet.ID, + Name: wallet.Name, + Homepage: wallet.Homepage, + }, + Status: receiverWallet.Status, + OTP: "123456", + OTPCreatedAt: actual.OTPCreatedAt, + StellarAddress: receiverWallet.StellarAddress, + StellarMemo: receiverWallet.StellarMemo, + StellarMemoType: receiverWallet.StellarMemoType, + } + + assert.Equal(t, expected, *actual) + }) + + query := `UPDATE receiver_wallets SET stellar_memo = NULL, stellar_memo_type = NULL WHERE id = $1` + _, err = dbConnectionPool.ExecContext(ctx, query, receiverWallet.ID) + require.NoError(t, err) + + t.Run("returns receiver_wallet when stellar account is provided and memo is null", func(t *testing.T) { + actual, err := receiverWalletModel.GetByStellarAccountAndMemo(ctx, receiverWallet.StellarAddress, "") + require.NoError(t, err) + + expected := ReceiverWallet{ + ID: receiverWallet.ID, + Receiver: Receiver{ID: receiver.ID}, + Wallet: Wallet{ + ID: wallet.ID, + Name: wallet.Name, + Homepage: wallet.Homepage, + }, + Status: receiverWallet.Status, + OTP: "123456", + OTPCreatedAt: actual.OTPCreatedAt, + StellarAddress: receiverWallet.StellarAddress, + StellarMemo: "", + StellarMemoType: "", + } + + assert.Equal(t, expected, *actual) + }) + + t.Run("won't find a result if stellar account and memo are provided, but the DB memo is NULL", func(t *testing.T) { + actual, err := receiverWalletModel.GetByStellarAccountAndMemo(ctx, receiverWallet.StellarAddress, receiverWallet.StellarMemo) + require.ErrorIs(t, err, ErrRecordNotFound) + require.Empty(t, actual) + }) +} diff --git a/internal/data/roles.go b/internal/data/roles.go new file mode 100644 index 000000000..b24414690 --- /dev/null +++ b/internal/data/roles.go @@ -0,0 +1,46 @@ +package data + +type UserRole string + +func (u UserRole) String() string { + return string(u) +} + +func (u UserRole) IsValid() bool { + switch u { + case OwnerUserRole, FinancialControllerUserRole, DeveloperUserRole, BusinessUserRole: + return true + } + return false +} + +// Roles description reference: https://stellarfoundation.slack.com/archives/C04C9MLM9UZ/p1681238994830149 +const ( + // OwnerUserRole has permission to do everything. Also, it's in charge of creating new users and managing Org account. + OwnerUserRole UserRole = "owner" + // FinancialControllerUserRole has the same permissions as the OwnerUserRole except for user management. + FinancialControllerUserRole UserRole = "financial_controller" + // DeveloperUserRole has only configuration permissions. (wallets, assets, countries management. Also, statistics access permission) + DeveloperUserRole UserRole = "developer" + // BusinessUserRole has read-only permissions - except for user management that they can't read any data. + BusinessUserRole UserRole = "business" +) + +// GetAllRoles returns all roles available +func GetAllRoles() []UserRole { + return []UserRole{ + OwnerUserRole, + FinancialControllerUserRole, + DeveloperUserRole, + BusinessUserRole, + } +} + +// FromUserRoleArrayToStringArray converts an array of UserRole type to an array of string +func FromUserRoleArrayToStringArray(roles []UserRole) []string { + rolesString := make([]string, 0, len(roles)) + for _, role := range roles { + rolesString = append(rolesString, role.String()) + } + return rolesString +} diff --git a/internal/data/roles_test.go b/internal/data/roles_test.go new file mode 100644 index 000000000..9d7b18254 --- /dev/null +++ b/internal/data/roles_test.go @@ -0,0 +1,15 @@ +package data + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_UserRole_IsValid(t *testing.T) { + role := UserRole("unknown") + assert.False(t, role.IsValid()) + + role = UserRole("developer") + assert.True(t, role.IsValid()) +} diff --git a/internal/data/statemachine.go b/internal/data/statemachine.go new file mode 100644 index 000000000..ac82c3be9 --- /dev/null +++ b/internal/data/statemachine.go @@ -0,0 +1,46 @@ +package data + +import "fmt" + +type State string + +type StateTransition struct { + From State + To State +} + +type StateMachine struct { + CurrentState State + Transitions map[State]map[State]bool +} + +func NewStateMachine(initialState State, transitions []StateTransition) *StateMachine { + sm := &StateMachine{ + CurrentState: initialState, + Transitions: make(map[State]map[State]bool), + } + + for _, t := range transitions { + if sm.Transitions[t.From] == nil { + sm.Transitions[t.From] = make(map[State]bool) + } + sm.Transitions[t.From][t.To] = true + } + + return sm +} + +func (sm *StateMachine) CanTransitionTo(targetState State) bool { + if _, ok := sm.Transitions[sm.CurrentState]; !ok { + return false + } + return sm.Transitions[sm.CurrentState][targetState] +} + +func (sm *StateMachine) TransitionTo(targetState State) error { + if sm.CanTransitionTo(targetState) { + sm.CurrentState = targetState + return nil + } + return fmt.Errorf("cannot transition from %s to %s", sm.CurrentState, targetState) +} diff --git a/internal/data/wallets.go b/internal/data/wallets.go new file mode 100644 index 000000000..53b425689 --- /dev/null +++ b/internal/data/wallets.go @@ -0,0 +1,171 @@ +package data + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type Wallet struct { + ID string `json:"id" db:"id"` + Name string `json:"name" db:"name"` + Homepage string `json:"homepage,omitempty" db:"homepage"` + SEP10ClientDomain string `json:"sep_10_client_domain,omitempty" db:"sep_10_client_domain"` + DeepLinkSchema string `json:"deep_link_schema,omitempty" db:"deep_link_schema"` + CreatedAt *time.Time `json:"created_at,omitempty" db:"created_at"` + UpdatedAt *time.Time `json:"updated_at,omitempty" db:"updated_at"` + DeletedAt *time.Time `json:"-" db:"deleted_at"` +} + +type WalletModel struct { + dbConnectionPool db.DBConnectionPool +} + +func (w *WalletModel) Get(ctx context.Context, id string) (*Wallet, error) { + var wallet Wallet + query := ` + SELECT + w.id, + w.name, + w.homepage, + w.sep_10_client_domain, + w.deep_link_schema, + w.created_at, + w.updated_at + FROM + wallets w + WHERE + w.id = $1 + ` + + err := w.dbConnectionPool.GetContext(ctx, &wallet, query, id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying wallet ID %s: %w", id, err) + } + return &wallet, nil +} + +// GetByWalletName returns wallet filtering by wallet name. +func (w *WalletModel) GetByWalletName(ctx context.Context, name string) (*Wallet, error) { + var wallet Wallet + query := ` + SELECT + w.id, + w.name, + w.homepage, + w.sep_10_client_domain, + w.deep_link_schema, + w.created_at, + w.updated_at + FROM + wallets w + WHERE + w.name = $1 + ` + + err := w.dbConnectionPool.GetContext(ctx, &wallet, query, name) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying wallet with name %s: %w", name, err) + } + return &wallet, nil +} + +// GetAll returns all wallets in the database +func (w *WalletModel) GetAll(ctx context.Context) ([]Wallet, error) { + wallets := []Wallet{} + query := ` + SELECT + w.id, + w.name, + w.homepage, + w.sep_10_client_domain, + w.deep_link_schema, + w.created_at, + w.updated_at + FROM + wallets w + ORDER BY + name + ` + + err := w.dbConnectionPool.SelectContext(ctx, &wallets, query) + if err != nil { + return nil, fmt.Errorf("error querying wallets: %w", err) + } + return wallets, nil +} + +func (w *WalletModel) Insert(ctx context.Context, name string, homepage string, deepLink string, sep10Domain string) (*Wallet, error) { + const query = ` + INSERT INTO wallets + (name, homepage, deep_link_schema, sep_10_client_domain) + VALUES + ($1, $2, $3, $4) + RETURNING + id, + name, + homepage, + sep_10_client_domain, + deep_link_schema, + created_at, + updated_at + ` + + var wallet Wallet + err := w.dbConnectionPool.GetContext(ctx, &wallet, query, name, homepage, deepLink, sep10Domain) + if err != nil { + return nil, fmt.Errorf("error inserting wallet: %w", err) + } + + return &wallet, nil +} + +func (w *WalletModel) GetOrCreate(ctx context.Context, name, homepage, deepLink, sep10Domain string) (*Wallet, error) { + const query = ` + WITH create_wallet AS( + INSERT INTO wallets + (name, homepage, deep_link_schema, sep_10_client_domain) + VALUES + ($1, $2, $3, $4) + ON CONFLICT (name, homepage, deep_link_schema) DO NOTHING + RETURNING + id, + name, + homepage, + sep_10_client_domain, + deep_link_schema, + created_at, + updated_at + ) + SELECT * FROM create_wallet cw + UNION ALL + SELECT + id, + name, + homepage, + sep_10_client_domain, + deep_link_schema, + created_at, + updated_at + FROM wallets w + WHERE w.name = $1 + ` + + var wallet Wallet + err := w.dbConnectionPool.GetContext(ctx, &wallet, query, name, homepage, deepLink, sep10Domain) + if err != nil { + return nil, fmt.Errorf("error getting or creating wallet: %w", err) + } + + return &wallet, nil +} diff --git a/internal/data/wallets_test.go b/internal/data/wallets_test.go new file mode 100644 index 000000000..c0a2f143d --- /dev/null +++ b/internal/data/wallets_test.go @@ -0,0 +1,194 @@ +package data + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_WalletModelGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + walletModel := &WalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when wallet is not found", func(t *testing.T) { + _, err := walletModel.Get(ctx, "not-found") + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns wallet successfully", func(t *testing.T) { + expected := CreateWalletFixture(t, ctx, dbConnectionPool.SqlxDB(), + "NewWallet", + "https://newwallet.com", + "newwallet.com", + "newalletapp://") + + actual, err := walletModel.Get(ctx, expected.ID) + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) +} + +func Test_WalletModelGetByWalletName(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + walletModel := &WalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error when wallet is not found", func(t *testing.T) { + _, err := walletModel.GetByWalletName(ctx, "invalid name") + require.Error(t, err) + require.Equal(t, ErrRecordNotFound, err) + }) + + t.Run("returns wallet successfully", func(t *testing.T) { + expected := CreateWalletFixture(t, ctx, dbConnectionPool.SqlxDB(), + "NewWallet", + "https://newwallet.com", + "newwallet.com", + "newalletapp://") + + actual, err := walletModel.GetByWalletName(ctx, expected.Name) + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) +} + +func Test_WalletModelGetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + walletModel := &WalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns all wallets successfully", func(t *testing.T) { + expected := ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool.SqlxDB()) + actual, err := walletModel.GetAll(ctx) + require.NoError(t, err) + + assert.Equal(t, expected, actual) + }) + + t.Run("returns empty array when no wallets", func(t *testing.T) { + DeleteAllWalletFixtures(t, ctx, dbConnectionPool.SqlxDB()) + actual, err := walletModel.GetAll(ctx) + require.NoError(t, err) + + assert.Equal(t, []Wallet{}, actual) + }) +} + +func Test_WalletModelInsert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + walletModel := &WalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("inserts wallet successfully", func(t *testing.T) { + DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + name := "test_wallet" + homepage := "https://www.test_wallet.com" + deep_link_schema := "test_wallet://" + sep_10_client_domain := "www.test_wallet.com" + + wallet, err := walletModel.Insert(ctx, name, homepage, deep_link_schema, sep_10_client_domain) + require.NoError(t, err) + assert.NotNil(t, wallet) + + insertedWallet, err := walletModel.Get(ctx, wallet.ID) + require.NoError(t, err) + assert.NotNil(t, insertedWallet) + }) +} + +func Test_WalletModelGetOrCreate(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + walletModel := &WalletModel{dbConnectionPool: dbConnectionPool} + + t.Run("returns error wallet name already been used", func(t *testing.T) { + DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + CreateWalletFixture(t, ctx, dbConnectionPool.SqlxDB(), + "test_wallet", + "https://www.new_wallet.com", + "www.new_wallet.com", + "new_wallet://") + + name := "test_wallet" + homepage := "https://www.test_wallet.com" + deep_link_schema := "test_wallet://" + sep_10_client_domain := "www.test_wallet.com" + + wallet, err := walletModel.GetOrCreate(ctx, name, homepage, deep_link_schema, sep_10_client_domain) + require.EqualError(t, err, "error getting or creating wallet: pq: duplicate key value violates unique constraint \"wallets_name_key\"") + assert.Empty(t, wallet) + }) + + t.Run("inserts wallet successfully", func(t *testing.T) { + DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + name := "test_wallet" + homepage := "https://www.test_wallet.com" + deep_link_schema := "test_wallet://" + sep_10_client_domain := "www.test_wallet.com" + + wallet, err := walletModel.GetOrCreate(ctx, name, homepage, deep_link_schema, sep_10_client_domain) + require.NoError(t, err) + assert.Equal(t, "test_wallet", wallet.Name) + assert.Equal(t, "https://www.test_wallet.com", wallet.Homepage) + assert.Equal(t, "test_wallet://", wallet.DeepLinkSchema) + assert.Equal(t, "www.test_wallet.com", wallet.SEP10ClientDomain) + }) + + t.Run("returns wallet successfully", func(t *testing.T) { + DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + expected := CreateWalletFixture(t, ctx, dbConnectionPool.SqlxDB(), + "test_wallet", + "https://www.test_wallet.com", + "www.test_wallet.com", + "test_wallet://") + + name := "test_wallet" + homepage := "https://www.test_wallet.com" + deep_link_schema := "test_wallet://" + sep_10_client_domain := "www.test_wallet.com" + + wallet, err := walletModel.GetOrCreate(ctx, name, homepage, deep_link_schema, sep_10_client_domain) + require.NoError(t, err) + assert.Equal(t, expected.ID, wallet.ID) + }) +} diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 000000000..0cf93d234 --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,162 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/jmoiron/sqlx" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" +) + +const ( + MaxDBConnIdleTime = 10 * time.Second + MaxOpenDBConns = 30 +) + +// DBConnectionPool is an interface that wraps the sqlx.DB structs methods and includes the RunInTransaction helper. +type DBConnectionPool interface { + SQLExecuter + BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) + Close() error + Ping() error + SqlDB() *sql.DB + SqlxDB() *sqlx.DB +} + +// DBConnectionPoolImplementation is a wrapper around sqlx.DB that implements DBConnectionPool. +type DBConnectionPoolImplementation struct { + *sqlx.DB +} + +func (db *DBConnectionPoolImplementation) BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) { + return db.DB.BeginTxx(ctx, opts) +} + +func (db *DBConnectionPoolImplementation) SqlDB() *sql.DB { + return db.DB.DB +} + +func (db *DBConnectionPoolImplementation) SqlxDB() *sqlx.DB { + return db.DB +} + +// RunInTransactionWithResult runs the given atomic function in an atomic database transaction and returns a result and +// an error. Boilerplate code for database transactions. +func RunInTransactionWithResult[T any](ctx context.Context, dbConnectionPool DBConnectionPool, opts *sql.TxOptions, atomicFunction func(dbTx DBTransaction) (T, error)) (result T, err error) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, opts) + if err != nil { + return *new(T), fmt.Errorf("creating db transaction for RunInTransactionWithResult: %w", err) + } + + defer func() { + DBTxRollback(ctx, dbTx, err, "rolling back transaction due to error") + }() + + result, err = atomicFunction(dbTx) + if err != nil { + return *new(T), fmt.Errorf("running atomic function in RunInTransactionWithResult: %w", err) + } + + err = dbTx.Commit() + if err != nil { + return *new(T), fmt.Errorf("committing transaction in RunInTransactionWithResult: %w", err) + } + + return result, nil +} + +// RunInTransaction runs the given atomic function in an atomic database transaction and returns an error. Boilerplate +// code for database transactions. +func RunInTransaction(ctx context.Context, dbConnectionPool DBConnectionPool, opts *sql.TxOptions, atomicFunction func(dbTx DBTransaction) error) error { + // wrap the atomic function with a function that returns nil and an error so we can call RunInTransactionWithResult + wrappedFunction := func(dbTx DBTransaction) (interface{}, error) { + return nil, atomicFunction(dbTx) + } + + _, err := RunInTransactionWithResult(ctx, dbConnectionPool, opts, wrappedFunction) + return err +} + +// make sure *DBConnectionPoolImplementation implements DBConnectionPool: +var _ DBConnectionPool = (*DBConnectionPoolImplementation)(nil) + +// DBTransaction is an interface that wraps the sqlx.Tx structs methods. +type DBTransaction interface { + SQLExecuter + Rollback() error + Commit() error +} + +// make sure *sqlx.Tx implements DBTransaction: +var _ DBTransaction = (*sqlx.Tx)(nil) + +// SQLExecuter is an interface that wraps the *sqlx.DB and *sqlx.Tx structs methods. +type SQLExecuter interface { + DriverName() string + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + sqlx.PreparerContext + sqlx.QueryerContext + Rebind(query string) string + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error +} + +// make sure *sqlx.DB implements SQLExecuter: +var _ SQLExecuter = (*sqlx.DB)(nil) + +// make sure DBConnectionPool implements SQLExecuter: +var _ SQLExecuter = (DBConnectionPool)(nil) + +// make sure *sqlx.Tx implements SQLExecuter: +var _ SQLExecuter = (*sqlx.Tx)(nil) + +// make sure DBTransaction implements SQLExecuter: +var _ SQLExecuter = (DBTransaction)(nil) + +// DBTxRollback rolls back the transaction if there is an error. +func DBTxRollback(ctx context.Context, dbTx DBTransaction, err error, logMessage string) { + if err != nil { + log.Ctx(ctx).Errorf("%s: %s", logMessage, err.Error()) + errRollBack := dbTx.Rollback() + if errRollBack != nil { + log.Ctx(ctx).Errorf("error in database transaction rollback: %s", errRollBack.Error()) + } + } +} + +// OpenDBConnectionPool opens a new database connection pool. It returns an error if it can't connect to the database. +func OpenDBConnectionPool(dataSourceName string) (DBConnectionPool, error) { + sqlxDB, err := sqlx.Open("postgres", dataSourceName) + if err != nil { + return nil, fmt.Errorf("error creating app DB connection pool: %w", err) + } + sqlxDB.SetConnMaxIdleTime(MaxDBConnIdleTime) + sqlxDB.SetMaxOpenConns(MaxOpenDBConns) + + err = sqlxDB.Ping() + if err != nil { + return nil, fmt.Errorf("error pinging app DB connection pool: %w", err) + } + + return &DBConnectionPoolImplementation{DB: sqlxDB}, nil +} + +// OpenDBConnectionPoolWithMetrics opens a new database connection pool with the monitor service. It returns an error if it can't connect to the database. +func OpenDBConnectionPoolWithMetrics(dataSourceName string, monitorService monitor.MonitorServiceInterface) (DBConnectionPool, error) { + dbConnectionPool, err := OpenDBConnectionPool(dataSourceName) + if err != nil { + return nil, fmt.Errorf("error opening a new db connection pool: %w", err) + } + + return NewDBConnectionPoolWithMetrics(dbConnectionPool, monitorService) +} + +// CloseRows closes the given rows and logs an error if it can't close them. +func CloseRows(ctx context.Context, rows *sqlx.Rows) { + if err := rows.Close(); err != nil { + log.Ctx(ctx).Errorf("Failed to close rows: %v", err) + } +} diff --git a/internal/db/db_connection_pool_with_metrics.go b/internal/db/db_connection_pool_with_metrics.go new file mode 100644 index 000000000..6782889cb --- /dev/null +++ b/internal/db/db_connection_pool_with_metrics.go @@ -0,0 +1,56 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + + "github.com/jmoiron/sqlx" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" +) + +func NewDBConnectionPoolWithMetrics(dbConnectionPool DBConnectionPool, monitorServiceInterface monitor.MonitorServiceInterface) (*DBConnectionPoolWithMetrics, error) { + sqlExec, err := NewSQLExecuterWithMetrics(dbConnectionPool, monitorServiceInterface) + if err != nil { + return nil, fmt.Errorf("error creating SQLExecuterWithMetrics: %w", err) + } + + return &DBConnectionPoolWithMetrics{ + dbConnectionPool: dbConnectionPool, + SQLExecuterWithMetrics: *sqlExec, + }, nil +} + +// DBConnectionPoolWithMetrics is a wrapper around sqlx.DB that implements DBConnectionPool with the monitoring service. +type DBConnectionPoolWithMetrics struct { + dbConnectionPool DBConnectionPool + SQLExecuterWithMetrics +} + +// make sure *DBConnectionPoolWithMetrics implements DBConnectionPool: +var _ DBConnectionPool = (*DBConnectionPoolWithMetrics)(nil) + +func (dbc *DBConnectionPoolWithMetrics) BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) { + dbTransaction, err := dbc.dbConnectionPool.BeginTxx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("error starting a new transaction: %w", err) + } + + return NewDBTransactionWithMetrics(dbTransaction, dbc.monitorServiceInterface) +} + +func (dbc *DBConnectionPoolWithMetrics) Close() error { + return dbc.dbConnectionPool.Close() +} + +func (dbc *DBConnectionPoolWithMetrics) Ping() error { + return dbc.dbConnectionPool.Ping() +} + +func (dbc *DBConnectionPoolWithMetrics) SqlDB() *sql.DB { + return dbc.dbConnectionPool.SqlDB() +} + +func (dbc *DBConnectionPoolWithMetrics) SqlxDB() *sqlx.DB { + return dbc.dbConnectionPool.SqlxDB() +} diff --git a/internal/db/db_connection_pool_with_metrics_test.go b/internal/db/db_connection_pool_with_metrics_test.go new file mode 100644 index 000000000..161563d9e --- /dev/null +++ b/internal/db/db_connection_pool_with_metrics_test.go @@ -0,0 +1,75 @@ +package db + +import ( + "context" + "database/sql" + "testing" + + "github.com/jmoiron/sqlx" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDBConnectionPoolWithMetrics_SqlxDB(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + dbConnectionPoolWithMetrics, err := NewDBConnectionPoolWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + sqlxDB := dbConnectionPoolWithMetrics.SqlxDB() + + assert.IsType(t, &sqlx.DB{}, sqlxDB) +} + +func TestDBConnectionPoolWithMetrics_SqlDB(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + dbConnectionPoolWithMetrics, err := NewDBConnectionPoolWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + sqlDB := dbConnectionPoolWithMetrics.SqlDB() + + assert.IsType(t, &sql.DB{}, sqlDB) +} + +func TestDBConnectionPoolWithMetrics_BeginTxx(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + dbConnectionPoolWithMetrics, err := NewDBConnectionPoolWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + dbTxWithMetrics, err := dbConnectionPoolWithMetrics.BeginTxx(ctx, nil) + + // Defer a rollback in case anything fails. + defer func() { + err = dbTxWithMetrics.Rollback() + require.Error(t, err, "not in transaction") + }() + require.NoError(t, err) + + assert.IsType(t, &DBTransactionWithMetrics{}, dbTxWithMetrics) + + err = dbTxWithMetrics.Commit() + require.NoError(t, err) +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 000000000..ef4fa41fd --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,40 @@ +package db + +import ( + "testing" + + "github.com/stellar/go/support/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpen_OpenDBConnectionPool(t *testing.T) { + db := dbtest.Postgres(t) + defer db.Close() + + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + assert.Equal(t, "postgres", dbConnectionPool.DriverName()) + + err = dbConnectionPool.Ping() + require.NoError(t, err) +} + +func TestOpen_OpenDBConnectionPoolWithMetrics(t *testing.T) { + db := dbtest.Postgres(t) + defer db.Close() + + mMonitorService := &monitor.MockMonitorService{} + + dbConnectionPoolWithMetrics, err := OpenDBConnectionPoolWithMetrics(db.DSN, mMonitorService) + require.NoError(t, err) + defer dbConnectionPoolWithMetrics.Close() + + assert.Equal(t, "postgres", dbConnectionPoolWithMetrics.DriverName()) + + err = dbConnectionPoolWithMetrics.Ping() + require.NoError(t, err) +} diff --git a/internal/db/db_transaction_with_metrics.go b/internal/db/db_transaction_with_metrics.go new file mode 100644 index 000000000..e802f840a --- /dev/null +++ b/internal/db/db_transaction_with_metrics.go @@ -0,0 +1,36 @@ +package db + +import ( + "fmt" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" +) + +func NewDBTransactionWithMetrics(dbTransaction DBTransaction, monitorServiceInterface monitor.MonitorServiceInterface) (*DBTransactionWithMetrics, error) { + sqlExec, err := NewSQLExecuterWithMetrics(dbTransaction, monitorServiceInterface) + if err != nil { + return nil, fmt.Errorf("error creating SQLExecuterWithMetrics: %w", err) + } + + return &DBTransactionWithMetrics{ + dbTransaction: dbTransaction, + SQLExecuterWithMetrics: *sqlExec, + }, nil +} + +// DBTransactionWithMetrics is an interface that implements DBTransaction with the monitoring service. +type DBTransactionWithMetrics struct { + dbTransaction DBTransaction + SQLExecuterWithMetrics +} + +// make sure DBTransactionWithMetrics implements DBTransaction: +var _ DBTransaction = (*DBTransactionWithMetrics)(nil) + +func (dbTx *DBTransactionWithMetrics) Commit() error { + return dbTx.dbTransaction.Commit() +} + +func (dbTx *DBTransactionWithMetrics) Rollback() error { + return dbTx.dbTransaction.Rollback() +} diff --git a/internal/db/db_transaction_with_metrics_test.go b/internal/db/db_transaction_with_metrics_test.go new file mode 100644 index 000000000..c892b65c4 --- /dev/null +++ b/internal/db/db_transaction_with_metrics_test.go @@ -0,0 +1,55 @@ +package db + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stretchr/testify/require" +) + +func TestDBTransactionWithMetrics_Commit(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + ctx := context.Background() + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + dbTransactionWithMetrics, err := NewDBTransactionWithMetrics(dbTx, mMonitorService) + require.NoError(t, err) + + err = dbTransactionWithMetrics.Commit() + require.NoError(t, err) +} + +func TestDBTransactionWithMetrics_Rollback(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + ctx := context.Background() + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + + dbTransactionWithMetrics, err := NewDBTransactionWithMetrics(dbTx, mMonitorService) + require.NoError(t, err) + + err = dbTransactionWithMetrics.Rollback() + require.NoError(t, err) +} diff --git a/internal/db/dbtest/dbtest.go b/internal/db/dbtest/dbtest.go new file mode 100644 index 000000000..0042fd9ac --- /dev/null +++ b/internal/db/dbtest/dbtest.go @@ -0,0 +1,31 @@ +package dbtest + +import ( + "net/http" + "testing" + + migrate "github.com/rubenv/sql-migrate" + "github.com/stellar/go/support/db/dbtest" + "github.com/stellar/go/support/db/schema" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/migrations" +) + +func OpenWithoutMigrations(t *testing.T) *dbtest.DB { + db := dbtest.Postgres(t) + return db +} + +func Open(t *testing.T) *dbtest.DB { + db := OpenWithoutMigrations(t) + + conn := db.Open() + defer conn.Close() + + migrateDirection := schema.MigrateUp + m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrations.FS)} + _, err := schema.Migrate(conn.DB, m, migrateDirection, 0) + if err != nil { + t.Fatal(err) + } + return db +} diff --git a/internal/db/dbtest/dbtest_test.go b/internal/db/dbtest/dbtest_test.go new file mode 100644 index 000000000..95f6213b1 --- /dev/null +++ b/internal/db/dbtest/dbtest_test.go @@ -0,0 +1,18 @@ +package dbtest + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpen(t *testing.T) { + db := Open(t) + session := db.Open() + + count := 0 + err := session.Get(&count, `SELECT COUNT(*) FROM gorp_migrations`) + require.NoError(t, err) + assert.Greater(t, count, 0) +} diff --git a/internal/db/migrate.go b/internal/db/migrate.go new file mode 100644 index 000000000..ae00e2b6d --- /dev/null +++ b/internal/db/migrate.go @@ -0,0 +1,22 @@ +package db + +import ( + "fmt" + "net/http" + + migrate "github.com/rubenv/sql-migrate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/migrations" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +func Migrate(dbURL string, dir migrate.MigrationDirection, count int) (int, error) { + dbConnectionPool, err := OpenDBConnectionPool(dbURL) + if err != nil { + return 0, fmt.Errorf("database URL '%s': %w", utils.TruncateString(dbURL, len(dbURL)/4), err) + } + defer dbConnectionPool.Close() + + ms := migrate.MigrationSet{} + m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrations.FS)} + return ms.ExecMax(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName(), m, dir, count) +} diff --git a/internal/db/migrate_test.go b/internal/db/migrate_test.go new file mode 100644 index 000000000..476383684 --- /dev/null +++ b/internal/db/migrate_test.go @@ -0,0 +1,86 @@ +package db + +import ( + "context" + "io/fs" + "testing" + + migrate "github.com/rubenv/sql-migrate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/migrations" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMigrate_upApplyOne(t *testing.T) { + db := dbtest.OpenWithoutMigrations(t) + defer db.Close() + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + n, err := Migrate(db.DSN, migrate.Up, 1) + require.NoError(t, err) + assert.Equal(t, 1, n) + + ids := []string{} + err = dbConnectionPool.SelectContext(ctx, &ids, `SELECT id FROM gorp_migrations`) + require.NoError(t, err) + wantIDs := []string{"2023-01-20.0-initial.sql"} + assert.Equal(t, wantIDs, ids) +} + +func TestMigrate_downApplyOne(t *testing.T) { + db := dbtest.OpenWithoutMigrations(t) + defer db.Close() + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + n, err := Migrate(db.DSN, migrate.Up, 2) + require.NoError(t, err) + require.Equal(t, 2, n) + + n, err = Migrate(db.DSN, migrate.Down, 1) + require.NoError(t, err) + require.Equal(t, 1, n) + + ids := []string{} + err = dbConnectionPool.SelectContext(ctx, &ids, `SELECT id FROM gorp_migrations`) + require.NoError(t, err) + wantIDs := []string{"2023-01-20.0-initial.sql"} + assert.Equal(t, wantIDs, ids) +} + +func TestMigrate_upDownAll(t *testing.T) { + db := dbtest.OpenWithoutMigrations(t) + defer db.Close() + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + // Get number of files in the migrations directory: + var count int + err = fs.WalkDir(migrations.FS, ".", func(path string, d fs.DirEntry, err error) error { + require.NoError(t, err) + if !d.IsDir() { + count++ + } + return nil + }) + require.NoError(t, err) + + n, err := Migrate(db.DSN, migrate.Up, count) + require.NoError(t, err) + require.Equal(t, count, n) + + // TODO: fix DB transactions to make sure we can migrate down all the way + migrateDownCount := count - 6 + n, err = Migrate(db.DSN, migrate.Down, migrateDownCount) + require.NoError(t, err) + require.Equal(t, migrateDownCount, n) +} diff --git a/internal/db/migrations/2023-01-20.0-initial.sql b/internal/db/migrations/2023-01-20.0-initial.sql new file mode 100644 index 000000000..21884dcac --- /dev/null +++ b/internal/db/migrations/2023-01-20.0-initial.sql @@ -0,0 +1,7 @@ +-- This migration file is intentionally empty and is a first starting point for +-- our migrations before we yet have a schema. + +-- +migrate Up + +-- +migrate Down + diff --git a/internal/db/migrations/2023-01-23.0-dump-from-sdp-v1.sql b/internal/db/migrations/2023-01-23.0-dump-from-sdp-v1.sql new file mode 100644 index 000000000..7f95aaf38 --- /dev/null +++ b/internal/db/migrations/2023-01-23.0-dump-from-sdp-v1.sql @@ -0,0 +1,372 @@ +-- This migration file is meant to reproduce the database schema from the SDP v1, so we can support users that +-- are on SDP v1 to migrate to SDP v2. + +-- +migrate Up + + +------------------------------------------------- START DJANGO MODELS ------------------------------------------------- + +-- TABLE: auth_group +CREATE TABLE IF NOT EXISTS public.auth_group ( + id SERIAL PRIMARY KEY, + name character varying(150) NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS auth_group_name_a6ea08ec_like ON public.auth_group (name varchar_pattern_ops); +ALTER INDEX auth_group_name_a6ea08ec_like RENAME TO auth_group_name_idx; + + +-- TABLE: django_content_type +CREATE TABLE IF NOT EXISTS public.django_content_type ( + id SERIAL PRIMARY KEY, + app_label character varying(100) NOT NULL, + model character varying(100) NOT NULL, + UNIQUE (app_label, model) +); + +INSERT INTO public.django_content_type VALUES (1, 'admin', 'logentry') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (2, 'auth', 'permission') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (3, 'auth', 'group') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (4, 'auth', 'user') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (5, 'contenttypes', 'contenttype') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (6, 'sessions', 'session') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (7, 'payments', 'account') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (8, 'payments', 'disbursement') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (9, 'payments', 'heartbeat') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (10, 'payments', 'payment') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (11, 'payments', 'activation') ON CONFLICT (id) DO NOTHING; +INSERT INTO public.django_content_type VALUES (12, 'payments', 'withdrawal') ON CONFLICT (id) DO NOTHING; + + +-- TABLE: auth_permission +CREATE TABLE IF NOT EXISTS public.auth_permission ( + id SERIAL PRIMARY KEY, + name character varying(255) NOT NULL, + content_type_id integer NOT NULL REFERENCES public.django_content_type (id) DEFERRABLE INITIALLY DEFERRED, + codename character varying(100) NOT NULL, + UNIQUE (content_type_id, codename) +); +CREATE INDEX IF NOT EXISTS auth_permission_content_type_id_2f476e4b ON public.auth_permission USING btree (content_type_id); +ALTER INDEX auth_permission_content_type_id_2f476e4b RENAME TO auth_permission_content_type_id_idx; + +INSERT INTO public.auth_permission VALUES (1, 'Can add log entry', 1, 'add_logentry') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (2, 'Can change log entry', 1, 'change_logentry') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (3, 'Can delete log entry', 1, 'delete_logentry') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (4, 'Can view log entry', 1, 'view_logentry') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (5, 'Can add permission', 2, 'add_permission') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (6, 'Can change permission', 2, 'change_permission') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (7, 'Can delete permission', 2, 'delete_permission') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (8, 'Can view permission', 2, 'view_permission') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (9, 'Can add group', 3, 'add_group') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (10, 'Can change group', 3, 'change_group') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (11, 'Can delete group', 3, 'delete_group') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (12, 'Can view group', 3, 'view_group') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (13, 'Can add user', 4, 'add_user') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (14, 'Can change user', 4, 'change_user') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (15, 'Can delete user', 4, 'delete_user') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (16, 'Can view user', 4, 'view_user') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (17, 'Can add content type', 5, 'add_contenttype') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (18, 'Can change content type', 5, 'change_contenttype') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (19, 'Can delete content type', 5, 'delete_contenttype') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (20, 'Can view content type', 5, 'view_contenttype') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (21, 'Can add session', 6, 'add_session') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (22, 'Can change session', 6, 'change_session') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (23, 'Can delete session', 6, 'delete_session') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (24, 'Can view session', 6, 'view_session') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (25, 'Can add account', 7, 'add_account') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (26, 'Can change account', 7, 'change_account') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (27, 'Can delete account', 7, 'delete_account') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (28, 'Can view account', 7, 'view_account') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (29, 'Can add disbursement', 8, 'add_disbursement') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (30, 'Can change disbursement', 8, 'change_disbursement') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (31, 'Can delete disbursement', 8, 'delete_disbursement') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (32, 'Can view disbursement', 8, 'view_disbursement') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (33, 'Can add heart beat', 9, 'add_heartbeat') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (34, 'Can change heart beat', 9, 'change_heartbeat') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (35, 'Can delete heart beat', 9, 'delete_heartbeat') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (36, 'Can view heart beat', 9, 'view_heartbeat') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (37, 'Can add payment', 10, 'add_payment') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (38, 'Can change payment', 10, 'change_payment') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (39, 'Can delete payment', 10, 'delete_payment') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (40, 'Can view payment', 10, 'view_payment') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (41, 'Can add activation', 11, 'add_activation') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (42, 'Can change activation', 11, 'change_activation') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (43, 'Can delete activation', 11, 'delete_activation') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (44, 'Can view activation', 11, 'view_activation') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (45, 'Can add withdrawal', 12, 'add_withdrawal') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (46, 'Can change withdrawal', 12, 'change_withdrawal') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (47, 'Can delete withdrawal', 12, 'delete_withdrawal') ON CONFLICT DO NOTHING; +INSERT INTO public.auth_permission VALUES (48, 'Can view withdrawal', 12, 'view_withdrawal') ON CONFLICT DO NOTHING; + + +-- TABLE: auth_group_permissions +CREATE TABLE IF NOT EXISTS public.auth_group_permissions ( + id BIGSERIAL PRIMARY KEY, + group_id integer NOT NULL REFERENCES public.auth_group (id) DEFERRABLE INITIALLY DEFERRED, + permission_id integer NOT NULL REFERENCES public.auth_permission (id) DEFERRABLE INITIALLY DEFERRED, + UNIQUE (group_id, permission_id) +); +CREATE INDEX IF NOT EXISTS auth_group_permissions_group_id_b120cbf9 ON public.auth_group_permissions USING btree (group_id); +ALTER INDEX auth_group_permissions_group_id_b120cbf9 RENAME TO auth_group_permissions_group_id_idx; + +CREATE INDEX IF NOT EXISTS auth_group_permissions_permission_id_84c5c92e ON public.auth_group_permissions USING btree (permission_id); +ALTER INDEX auth_group_permissions_permission_id_84c5c92e RENAME TO auth_group_permissions_permission_id_idx; + + +-- TABLE: auth_user +CREATE TABLE IF NOT EXISTS public.auth_user ( + id SERIAL PRIMARY KEY, + password character varying(128) NOT NULL, + last_login timestamp with time zone, + is_superuser boolean NOT NULL, + username character varying(150) NOT NULL, + first_name character varying(150) NOT NULL, + last_name character varying(150) NOT NULL, + email character varying(254) NOT NULL, + is_staff boolean NOT NULL, + is_active boolean NOT NULL, + date_joined timestamp with time zone NOT NULL, + UNIQUE (username) +); +CREATE INDEX IF NOT EXISTS auth_user_username_6821ab7c_like ON public.auth_user USING btree (username varchar_pattern_ops); +ALTER INDEX auth_user_username_6821ab7c_like RENAME TO auth_user_username_idx; + + +-- TABLE: auth_user_groups +CREATE TABLE IF NOT EXISTS public.auth_user_groups ( + id BIGSERIAL PRIMARY KEY, + user_id integer NOT NULL REFERENCES public.auth_user (id) DEFERRABLE INITIALLY DEFERRED, + group_id integer NOT NULL REFERENCES public.auth_group (id) DEFERRABLE INITIALLY DEFERRED, + UNIQUE (user_id, group_id) +); +CREATE INDEX IF NOT EXISTS auth_user_groups_group_id_97559544 ON public.auth_user_groups USING btree (group_id); +ALTER INDEX auth_user_groups_group_id_97559544 RENAME TO auth_user_groups_group_id_idx; + +CREATE INDEX IF NOT EXISTS auth_user_groups_user_id_6a12ed8b ON public.auth_user_groups USING btree (user_id); +ALTER INDEX auth_user_groups_user_id_6a12ed8b RENAME TO auth_user_groups_user_id_idx; + + +-- TABLE: auth_user_user_permissions +CREATE TABLE IF NOT EXISTS public.auth_user_user_permissions ( + id BIGSERIAL PRIMARY KEY, + user_id integer NOT NULL REFERENCES public.auth_user (id) DEFERRABLE INITIALLY DEFERRED, + permission_id integer NOT NULL REFERENCES public.auth_permission (id) DEFERRABLE INITIALLY DEFERRED, + UNIQUE (user_id, permission_id) +); +CREATE INDEX IF NOT EXISTS auth_user_user_permissions_permission_id_1fbb5f2c ON public.auth_user_user_permissions USING btree (permission_id); +ALTER INDEX auth_user_user_permissions_permission_id_1fbb5f2c RENAME TO auth_user_user_permissions_permission_id_idx; + +CREATE INDEX IF NOT EXISTS auth_user_user_permissions_user_id_a95ead1b ON public.auth_user_user_permissions USING btree (user_id); +ALTER INDEX auth_user_user_permissions_user_id_a95ead1b RENAME TO auth_user_user_permissions_user_id_idx; + + +-- TABLE: django_admin_log +CREATE TABLE IF NOT EXISTS public.django_admin_log ( + id SERIAL PRIMARY KEY, + action_time timestamp with time zone NOT NULL, + object_id text, + object_repr character varying(200) NOT NULL, + action_flag smallint NOT NULL, + change_message text NOT NULL, + content_type_id integer REFERENCES public.django_content_type(id) DEFERRABLE INITIALLY DEFERRED, + user_id integer NOT NULL REFERENCES public.auth_user(id) DEFERRABLE INITIALLY DEFERRED, + CONSTRAINT django_admin_log_action_flag_check CHECK ((action_flag >= 0)) +); +CREATE INDEX IF NOT EXISTS django_admin_log_content_type_id_c4bce8eb ON public.django_admin_log USING btree (content_type_id); +ALTER INDEX django_admin_log_content_type_id_c4bce8eb RENAME TO django_admin_log_content_type_id_idx; + +CREATE INDEX IF NOT EXISTS django_admin_log_user_id_c564eba6 ON public.django_admin_log USING btree (user_id); +ALTER INDEX django_admin_log_user_id_c564eba6 RENAME TO django_admin_log_user_id_idx; + + +-- TABLE: django_content_type +CREATE TABLE IF NOT EXISTS public.django_migrations ( + id BIGSERIAL PRIMARY KEY, + app character varying(255) NOT NULL, + name character varying(255) NOT NULL, + applied timestamp with time zone NOT NULL +); + +INSERT INTO public.django_migrations VALUES (1, 'contenttypes', '0001_initial', '2023-01-04 16:05:52.644099-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (2, 'auth', '0001_initial', '2023-01-04 16:05:52.711348-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (3, 'admin', '0001_initial', '2023-01-04 16:05:52.731795-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (4, 'admin', '0002_logentry_remove_auto_add', '2023-01-04 16:05:52.742003-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (5, 'admin', '0003_logentry_add_action_flag_choices', '2023-01-04 16:05:52.752853-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (6, 'contenttypes', '0002_remove_content_type_name', '2023-01-04 16:05:52.770614-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (7, 'auth', '0002_alter_permission_name_max_length', '2023-01-04 16:05:52.780492-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (8, 'auth', '0003_alter_user_email_max_length', '2023-01-04 16:05:52.791342-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (9, 'auth', '0004_alter_user_username_opts', '2023-01-04 16:05:52.802137-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (10, 'auth', '0005_alter_user_last_login_null', '2023-01-04 16:05:52.811967-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (11, 'auth', '0006_require_contenttypes_0002', '2023-01-04 16:05:52.814286-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (12, 'auth', '0007_alter_validators_add_error_messages', '2023-01-04 16:05:52.824612-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (13, 'auth', '0008_alter_user_username_max_length', '2023-01-04 16:05:52.835405-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (14, 'auth', '0009_alter_user_last_name_max_length', '2023-01-04 16:05:52.846608-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (15, 'auth', '0010_alter_group_name_max_length', '2023-01-04 16:05:52.858149-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (16, 'auth', '0011_update_proxy_permissions', '2023-01-04 16:05:52.867694-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (17, 'auth', '0012_alter_user_first_name_max_length', '2023-01-04 16:05:52.879894-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (18, 'payments', '0001_initial', '2023-01-04 16:05:52.947015-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (19, 'payments', '0002_remove_disbursement_requested_by_and_more', '2023-01-04 16:05:52.985496-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (20, 'payments', '0003_remove_disbursement_amount_and_more', '2023-01-04 16:05:53.040818-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (21, 'payments', '0004_account_link_last_sent_at', '2023-01-04 16:05:53.051175-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (22, 'payments', '0005_account_date_of_birth_account_email_and_more', '2023-01-04 16:05:53.079481-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (23, 'payments', '0006_payment_idempotency_key_alter_account_status', '2023-01-04 16:05:53.09219-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (24, 'payments', '0007_rename_hashed_date_of_birth_account_hashed_extra_info_and_more', '2023-01-04 16:05:53.108973-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (25, 'payments', '0008_activation', '2023-01-04 16:05:53.132048-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (26, 'payments', '0009_add_yubikey_validation_service', '2023-01-04 16:05:53.146663-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (27, 'payments', '0010_alter_account_phone_number', '2023-01-04 16:05:53.155158-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (28, 'payments', '0011_alter_payment_status', '2023-01-04 16:05:53.160529-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (29, 'payments', '0012_alter_payment_amount', '2023-01-04 16:05:53.173929-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (30, 'payments', '0013_withdrawal_alter_account_status_and_more', '2023-01-04 16:05:53.211413-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (31, 'payments', '0014_payment_withdrawal_amount_payment_withdrawal_status', '2023-01-04 16:05:53.228509-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (32, 'payments', '0015_rename_stellar_transaction_id_withdrawal_sep24_transaction_id', '2023-01-04 16:05:53.237819-08') ON CONFLICT DO NOTHING; +INSERT INTO public.django_migrations VALUES (33, 'sessions', '0001_initial', '2023-01-04 16:05:53.247677-08') ON CONFLICT DO NOTHING; + + +-- TABLE: django_session +CREATE TABLE IF NOT EXISTS public.django_session ( + session_key character varying(40) NOT NULL PRIMARY KEY, + session_data text NOT NULL, + expire_date timestamp with time zone NOT NULL +); +CREATE INDEX IF NOT EXISTS django_session_expire_date_a5c62663 ON public.django_session USING btree (expire_date); +ALTER INDEX django_session_expire_date_a5c62663 RENAME TO django_session_expire_date_idx; + +CREATE INDEX IF NOT EXISTS django_session_session_key_c0390e0f_like ON public.django_session USING btree (session_key varchar_pattern_ops); +ALTER INDEX django_session_session_key_c0390e0f_like RENAME TO django_session_session_key_idx; + +------------------------------------------------- FINISH DJANGO MODELS ------------------------------------------------- + + +------------------------------------------------- START OF SDP MODELS ------------------------------------------------- + +-- TABLE: receiver (previously known as payments_account) +CREATE TABLE IF NOT EXISTS public.payments_account ( + id character varying(64) NOT NULL PRIMARY KEY, + public_key character varying(128), + registered_at timestamp with time zone NOT NULL, + phone_number character varying(32) NOT NULL, + public_key_registered_at timestamp with time zone, + status character varying(32) NOT NULL, + link_last_sent_at timestamp with time zone, + email character varying(254), + email_registered_at timestamp with time zone, + hashed_extra_info character varying(64) NOT NULL, + hashed_phone_number character varying(64) NOT NULL, + extra_info character varying(64) NOT NULL, + UNIQUE (phone_number) +); +CREATE INDEX IF NOT EXISTS payments_ac_hashed__f9420c_idx ON public.payments_account USING btree (hashed_phone_number, hashed_extra_info); +ALTER INDEX payments_ac_hashed__f9420c_idx RENAME TO receiver_hashed_phone_and_hashed_extra_info_idx; + +CREATE INDEX IF NOT EXISTS payments_account_phone_number_221a9f17_like ON public.payments_account USING btree (phone_number varchar_pattern_ops); +ALTER INDEX payments_account_phone_number_221a9f17_like RENAME TO receiver_phone_number_idx; + +CREATE INDEX IF NOT EXISTS payments_ac_registe_104353_idx ON public.payments_account USING btree (registered_at DESC); +ALTER INDEX payments_ac_registe_104353_idx RENAME TO receiver_registered_at_idx; + +ALTER TABLE payments_account RENAME TO receivers; + + +-- TABLE: on_off_switch (previously known as payments_activation) +CREATE TABLE IF NOT EXISTS public.payments_activation ( + id BIGSERIAL PRIMARY KEY, + is_active boolean NOT NULL, + last_set_at timestamp with time zone NOT NULL DEFAULT NOW() +); +INSERT INTO public.payments_activation VALUES (1, true, NOW()) ON CONFLICT DO NOTHING; + +ALTER TABLE payments_activation RENAME TO on_off_switch; + + +-- TABLE: disbursement (previously known as payments_disbursement) +CREATE TABLE IF NOT EXISTS public.payments_disbursement ( + id character varying(64) NOT NULL PRIMARY KEY, + requested_at timestamp with time zone NOT NULL +); +CREATE INDEX IF NOT EXISTS payments_di_request_16523d_idx ON public.payments_disbursement USING btree (requested_at DESC); +ALTER INDEX payments_di_request_16523d_idx RENAME TO disbursement_request_16523d_idx; + +ALTER TABLE payments_disbursement RENAME TO disbursements; + + +-- TABLE: payments_semaphore (previously known as payments_heartbeat) +CREATE TABLE IF NOT EXISTS public.payments_heartbeat ( + id BIGSERIAL PRIMARY KEY, + name character varying(128) NOT NULL, + last_beat timestamp with time zone NOT NULL +); +ALTER TABLE payments_heartbeat RENAME TO payments_semaphore; + + +-- TABLE: payment (previously known as payments_payment) +CREATE TABLE IF NOT EXISTS public.payments_payment ( + id character varying(64) NOT NULL PRIMARY KEY, + stellar_transaction_id character varying(64), + custodial_payment_id text, + status character varying(32) NOT NULL, + status_message text, + requested_at timestamp with time zone NOT NULL, + started_at timestamp with time zone, + completed_at timestamp with time zone, + account_id character varying(64) NOT NULL REFERENCES public.receivers(id) DEFERRABLE INITIALLY DEFERRED, + disbursement_id character varying(64) NOT NULL REFERENCES public.disbursements(id) DEFERRABLE INITIALLY DEFERRED, + amount numeric(7,2) NOT NULL, + idempotency_key character varying(64) NOT NULL, + withdrawal_amount numeric(7,2) NOT NULL, + withdrawal_status character varying(32) NOT NULL +); +CREATE INDEX IF NOT EXISTS payments_pa_request_4ce797_idx ON public.payments_payment USING btree (requested_at DESC); +ALTER INDEX payments_pa_request_4ce797_idx RENAME TO payment_requested_at_idx; + +CREATE INDEX IF NOT EXISTS payments_payment_account_id_af225a32 ON public.payments_payment USING btree (account_id); +ALTER INDEX payments_payment_account_id_af225a32 RENAME TO payment_account_id_idx; + +CREATE INDEX IF NOT EXISTS payments_payment_account_id_af225a32_like ON public.payments_payment USING btree (account_id varchar_pattern_ops); +ALTER INDEX payments_payment_account_id_af225a32_like RENAME TO payment_account_id_like_idx; + +CREATE INDEX IF NOT EXISTS payments_payment_disbursement_id_2a817b83 ON public.payments_payment USING btree (disbursement_id); +ALTER INDEX payments_payment_disbursement_id_2a817b83 RENAME TO payment_disbursement_id_idx; + +ALTER TABLE payments_payment RENAME TO payments; + + +-- TABLE: payments_withdrawal +CREATE TABLE IF NOT EXISTS public.payments_withdrawal ( + sep24_transaction_id character varying(64) NOT NULL PRIMARY KEY, + anchor_id character varying(64) NOT NULL, + amount numeric(7,2) NOT NULL, + started_at timestamp with time zone NOT NULL, + completed_at timestamp with time zone NOT NULL, + created_at timestamp with time zone NOT NULL, + account_id character varying(64) NOT NULL REFERENCES public.receivers (id) DEFERRABLE INITIALLY DEFERRED +); +CREATE INDEX IF NOT EXISTS payments_wi_created_18b04a_idx ON public.payments_withdrawal USING btree (created_at DESC); +ALTER INDEX payments_wi_created_18b04a_idx RENAME TO withdrawal_created_at_idx; + +CREATE INDEX IF NOT EXISTS payments_withdrawal_account_id_ec0819dd ON public.payments_withdrawal USING btree (account_id); +ALTER INDEX payments_withdrawal_account_id_ec0819dd RENAME TO withdrawal_account_id_idx; + +CREATE INDEX IF NOT EXISTS payments_withdrawal_account_id_ec0819dd_like ON public.payments_withdrawal USING btree (account_id varchar_pattern_ops); +ALTER INDEX payments_withdrawal_account_id_ec0819dd_like RENAME TO withdrawal_account_id_like_idx; + +ALTER TABLE payments_withdrawal RENAME TO withdrawal; + + +-- +migrate Down + +DROP TABLE IF EXISTS public.withdrawal CASCADE; -- Called 'payments_withdrawal' in SDP-v1 +DROP TABLE IF EXISTS public.payments CASCADE; -- Called 'payments_payment' in SDP-v1 +DROP TABLE IF EXISTS public.payments_semaphore CASCADE; -- Called 'payments_heartbeat' in SDP-v1 +DROP TABLE IF EXISTS public.disbursements CASCADE; -- Called 'payments_disbursement' in SDP-v1 +DROP TABLE IF EXISTS public.on_off_switch CASCADE; -- Called 'payments_activation' in SDP-v1 +DROP TABLE IF EXISTS public.receivers CASCADE; -- Called 'payments_account' in SDP-v1 + +DROP TABLE IF EXISTS public.django_session CASCADE; +DROP TABLE IF EXISTS public.django_migrations CASCADE; +DROP TABLE IF EXISTS public.django_admin_log CASCADE; +DROP TABLE IF EXISTS public.auth_user_user_permissions CASCADE; +DROP TABLE IF EXISTS public.auth_user_groups CASCADE; +DROP TABLE IF EXISTS public.auth_user CASCADE; +DROP TABLE IF EXISTS public.auth_group_permissions CASCADE; +DROP TABLE IF EXISTS public.auth_permission CASCADE; +DROP TABLE IF EXISTS public.django_content_type CASCADE; +DROP TABLE IF EXISTS public.auth_group CASCADE; diff --git a/internal/db/migrations/2023-01-26.0-delete-all-django-stuff.sql b/internal/db/migrations/2023-01-26.0-delete-all-django-stuff.sql new file mode 100644 index 000000000..54c57df6e --- /dev/null +++ b/internal/db/migrations/2023-01-26.0-delete-all-django-stuff.sql @@ -0,0 +1,33 @@ +-- This migration dumps all django-related stuff that was in the database of the SDP v1. + + +-- +migrate Up + +DROP TABLE IF EXISTS public.django_session CASCADE; +DROP TABLE IF EXISTS public.django_migrations CASCADE; +DROP TABLE IF EXISTS public.django_admin_log CASCADE; +DROP TABLE IF EXISTS public.auth_user_user_permissions CASCADE; +DROP TABLE IF EXISTS public.auth_user_groups CASCADE; +DROP TABLE IF EXISTS public.auth_user CASCADE; +DROP TABLE IF EXISTS public.auth_group_permissions CASCADE; +DROP TABLE IF EXISTS public.auth_permission CASCADE; +DROP TABLE IF EXISTS public.django_content_type CASCADE; +DROP TABLE IF EXISTS public.auth_group CASCADE; +DROP TABLE IF EXISTS public.otp_static_staticdevice CASCADE; +DROP TABLE IF EXISTS public.otp_static_statictoken CASCADE; +DROP TABLE IF EXISTS public.otp_totp_totpdevice CASCADE; +DROP TABLE IF EXISTS public.otp_yubikey_remoteyubikeydevice CASCADE; +DROP TABLE IF EXISTS public.otp_yubikey_validationservice CASCADE; +DROP TABLE IF EXISTS public.otp_yubikey_yubikeydevice CASCADE; +DROP TABLE IF EXISTS public.two_factor_phonedevice CASCADE; + +DROP SEQUENCE IF EXISTS public.otp_static_staticdevice_id_seq CASCADE; +DROP SEQUENCE IF EXISTS public.otp_static_statictoken_id_seq CASCADE; +DROP SEQUENCE IF EXISTS public.otp_totp_totpdevice_id_seq CASCADE; +DROP SEQUENCE IF EXISTS public.otp_yubikey_remoteyubikeydevice_id_seq CASCADE; +DROP SEQUENCE IF EXISTS public.otp_yubikey_validationservice_id_seq CASCADE; +DROP SEQUENCE IF EXISTS public.otp_yubikey_yubikeydevice_id_seq CASCADE; + + +-- +migrate Down + diff --git a/internal/db/migrations/2023-01-26.1-drop-unused-sdp-v1-tables.sql b/internal/db/migrations/2023-01-26.1-drop-unused-sdp-v1-tables.sql new file mode 100644 index 000000000..402f48f70 --- /dev/null +++ b/internal/db/migrations/2023-01-26.1-drop-unused-sdp-v1-tables.sql @@ -0,0 +1,12 @@ +-- This migration dumps all django-related stuff that was in the database of the SDP v1. + + +-- +migrate Up + +DROP TABLE IF EXISTS public.on_off_switch CASCADE; +DROP TABLE IF EXISTS public.payments_semaphore CASCADE; +DROP TABLE IF EXISTS public.withdrawal CASCADE; + + +-- +migrate Down + diff --git a/internal/db/migrations/2023-01-26.2-updated-at-trigger.sql b/internal/db/migrations/2023-01-26.2-updated-at-trigger.sql new file mode 100644 index 000000000..5d6369cb1 --- /dev/null +++ b/internal/db/migrations/2023-01-26.2-updated-at-trigger.sql @@ -0,0 +1,18 @@ +-- Add function used to refresh the updated_at column automatically. + +-- +migrate Up + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION update_at_refresh() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ language 'plpgsql'; +-- +migrate StatementEnd + + +-- +migrate Down + +DROP FUNCTION update_at_refresh; diff --git a/internal/db/migrations/2023-01-27.0-create-assets-table.sql b/internal/db/migrations/2023-01-27.0-create-assets-table.sql new file mode 100644 index 000000000..42da65ce6 --- /dev/null +++ b/internal/db/migrations/2023-01-27.0-create-assets-table.sql @@ -0,0 +1,45 @@ +-- This creates the assets table and updates the other tables that depend on it. + +-- +migrate Up + +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; + +CREATE TABLE public.assets ( + id VARCHAR(36) PRIMARY KEY DEFAULT uuid_generate_v4(), + code VARCHAR(12) NOT NULL, + issuer VARCHAR(56) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMP WITH TIME ZONE, + UNIQUE (code, issuer), + CONSTRAINT asset_issuer_length_check CHECK (char_length(issuer) = 56) +); +INSERT INTO public.assets (code, issuer) VALUES ('USDC', 'GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5'); + +ALTER TABLE public.disbursements + ADD COLUMN asset_id VARCHAR(36), + ADD CONSTRAINT fk_disbursement_asset_id FOREIGN KEY (asset_id) REFERENCES public.assets (id); +UPDATE public.disbursements SET asset_id = (SELECT id FROM public.assets WHERE code = 'USDC' AND issuer = 'GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5'); +ALTER TABLE public.disbursements ALTER COLUMN asset_id SET NOT NULL; + +ALTER TABLE public.payments + ADD COLUMN asset_id VARCHAR(36), + ADD CONSTRAINT fk_payment_asset_id FOREIGN KEY (asset_id) REFERENCES public.assets (id); +UPDATE public.payments SET asset_id = (SELECT id FROM public.assets WHERE code = 'USDC' AND issuer = 'GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5'); +ALTER TABLE public.payments ALTER COLUMN asset_id SET NOT NULL; + +-- TRIGGER: updated_at +CREATE TRIGGER refresh_asset_updated_at BEFORE UPDATE ON public.assets FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down + +DROP TRIGGER refresh_asset_updated_at ON public.assets; + +ALTER TABLE public.payments DROP COLUMN asset_id; + +ALTER TABLE public.disbursements DROP COLUMN asset_id; + +DROP TABLE public.assets CASCADE; + +DROP EXTENSION IF EXISTS "uuid-ossp"; diff --git a/internal/db/migrations/2023-01-27.1-create-countries-table.sql b/internal/db/migrations/2023-01-27.1-create-countries-table.sql new file mode 100644 index 000000000..0b818532c --- /dev/null +++ b/internal/db/migrations/2023-01-27.1-create-countries-table.sql @@ -0,0 +1,29 @@ +-- This creates the countries table and updates the other tables that depend on it. + +-- +migrate Up + +CREATE TABLE public.countries ( + code VARCHAR(3) PRIMARY KEY, + name VARCHAR(100) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMP WITH TIME ZONE, + UNIQUE (name), + CONSTRAINT country_code_length_check CHECK (char_length(code) = 3) +); +INSERT INTO public.countries (code, name) VALUES ('UKR', 'Ukraine'); + +ALTER TABLE public.disbursements + ADD COLUMN country_code VARCHAR(3), + ADD CONSTRAINT fk_disbursement_country_code FOREIGN KEY (country_code) REFERENCES public.countries (code); +UPDATE public.disbursements SET country_code = 'UKR'; +ALTER TABLE public.disbursements ALTER COLUMN country_code SET NOT NULL; + +CREATE TRIGGER refresh_country_updated_at BEFORE UPDATE ON public.countries FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + +-- +migrate Down +DROP TRIGGER refresh_country_updated_at ON public.countries; + +ALTER TABLE public.disbursements DROP COLUMN country_code; + +DROP TABLE public.countries CASCADE; diff --git a/internal/db/migrations/2023-01-27.2-create-wallets-table.sql b/internal/db/migrations/2023-01-27.2-create-wallets-table.sql new file mode 100644 index 000000000..5ad0a77e9 --- /dev/null +++ b/internal/db/migrations/2023-01-27.2-create-wallets-table.sql @@ -0,0 +1,34 @@ +-- This creates the wallets table and updates the other tables that depend on it. + +-- +migrate Up + +CREATE TABLE public.wallets ( + id VARCHAR(36) PRIMARY KEY DEFAULT uuid_generate_v4(), + name VARCHAR(30) NOT NULL, + homepage VARCHAR(255) NOT NULL, + deep_link_schema VARCHAR(30) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMP WITH TIME ZONE, + UNIQUE (name), + UNIQUE (homepage), + UNIQUE (deep_link_schema) +); +-- TODO: keep in mind that the deep link `vibrantapp://` is not confirmed yet and is subject to change. +INSERT INTO public.wallets (name, homepage, deep_link_schema) VALUES ('Vibrant Assist', 'https://vibrantapp.com', 'https://vibrantapp.com/sdp-dev'); + +ALTER TABLE public.disbursements + ADD COLUMN wallet_id VARCHAR(36), + ADD CONSTRAINT fk_disbursement_wallet_id FOREIGN KEY (wallet_id) REFERENCES public.wallets (id); +UPDATE public.disbursements SET wallet_id = (SELECT id FROM public.wallets WHERE name = 'Vibrant Assist'); +ALTER TABLE public.disbursements ALTER COLUMN wallet_id SET NOT NULL; + +CREATE TRIGGER refresh_wallet_updated_at BEFORE UPDATE ON public.wallets FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down +DROP TRIGGER refresh_wallet_updated_at ON public.wallets; + +ALTER TABLE public.disbursements DROP COLUMN wallet_id; + +DROP TABLE public.wallets CASCADE; diff --git a/internal/db/migrations/2023-01-27.3-create-receiver-wallets-table.sql b/internal/db/migrations/2023-01-27.3-create-receiver-wallets-table.sql new file mode 100644 index 000000000..1784b5190 --- /dev/null +++ b/internal/db/migrations/2023-01-27.3-create-receiver-wallets-table.sql @@ -0,0 +1,37 @@ +-- This creates the receiver_wallets table and updates the other tables that depend on it. + +-- +migrate Up + +-- Table: receiver_wallets +CREATE TABLE public.receiver_wallets ( + id VARCHAR(36) PRIMARY KEY DEFAULT uuid_generate_v4(), + receiver_id VARCHAR(36) NOT NULL REFERENCES public.receivers (id), + wallet_id VARCHAR(36) REFERENCES public.wallets (id), + stellar_address VARCHAR(56), + stellar_memo VARCHAR(56), + stellar_memo_type VARCHAR(56), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + UNIQUE (receiver_id, wallet_id) +); +INSERT + INTO receiver_wallets (receiver_id, stellar_address) + (SELECT id, public_key FROM receivers WHERE public_key IS NOT NULL); +UPDATE public.receiver_wallets SET wallet_id = (SELECT id FROM public.wallets WHERE name = 'Vibrant Assist'); +ALTER TABLE public.receiver_wallets ALTER COLUMN wallet_id SET NOT NULL; + +-- Table: receivers +ALTER TABLE public.receivers DROP COLUMN public_key; + +CREATE TRIGGER refresh_receiver_wallet_updated_at BEFORE UPDATE ON public.receiver_wallets FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down +DROP TRIGGER refresh_receiver_wallet_updated_at ON public.receiver_wallets; + +-- Table: receivers +ALTER TABLE public.receivers ADD COLUMN public_key VARCHAR(128); +UPDATE public.receivers SET public_key = (SELECT stellar_address FROM public.receiver_wallets WHERE receiver_id = public.receivers.id); + +-- Table: receiver_wallets +DROP TABLE public.receiver_wallets CASCADE; diff --git a/internal/db/migrations/2023-01-27.4-create-messages-table.sql b/internal/db/migrations/2023-01-27.4-create-messages-table.sql new file mode 100644 index 000000000..e31bac02d --- /dev/null +++ b/internal/db/migrations/2023-01-27.4-create-messages-table.sql @@ -0,0 +1,27 @@ +-- This creates the messages table and updates the other tables that depend on it. + +-- +migrate Up + +CREATE TYPE message_type AS ENUM( + 'TWILIO_SMS', + 'AWS_SMS', + 'AWS_EMAIL' +); + +-- Table: messages +CREATE TABLE public.messages ( + id VARCHAR(36) PRIMARY KEY DEFAULT uuid_generate_v4(), + type message_type NOT NULL, + asset_id VARCHAR(36) NOT NULL REFERENCES public.assets (id), + wallet_id VARCHAR(36) NOT NULL REFERENCES public.wallets (id), + receiver_id VARCHAR(36) NOT NULL REFERENCES public.receivers (id), + text_encrypted VARCHAR(1024) NOT NULL, + title_encrypted VARCHAR(128), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +-- +migrate Down + +-- Table: messages +DROP TABLE public.messages CASCADE; +DROP TYPE message_type; diff --git a/internal/db/migrations/2023-01-30.0-update-disbursements-table.sql b/internal/db/migrations/2023-01-30.0-update-disbursements-table.sql new file mode 100644 index 000000000..077544b4a --- /dev/null +++ b/internal/db/migrations/2023-01-30.0-update-disbursements-table.sql @@ -0,0 +1,72 @@ +-- Update the disbursements table. + +-- +migrate Up + +CREATE TYPE disbursement_status AS ENUM( + 'DRAFT', + 'READY', + 'STARTED', + 'PAUSED', + 'COMPLETED' +); + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION create_disbursement_status_history(time_stamp TIMESTAMP WITH TIME ZONE, disb_status disbursement_status, user_id VARCHAR) +RETURNS jsonb AS $$ + BEGIN + RETURN json_build_object( + 'timestamp', time_stamp, + 'status', disb_status, + 'user_id', user_id + ); + END; +$$ LANGUAGE plpgsql; +-- +migrate StatementEnd + +ALTER TABLE public.disbursements + ALTER COLUMN id SET DEFAULT uuid_generate_v4(), + ADD COLUMN name VARCHAR(128), + ADD COLUMN status disbursement_status NOT NULL DEFAULT disbursement_status('DRAFT'), + ADD COLUMN status_history jsonb[] NOT NULL DEFAULT ARRAY[create_disbursement_status_history(NOW(), disbursement_status('DRAFT'), NULL)], + ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(); +-- TODO: Add column `uploaded_by_user_id` to disbursement table + +ALTER TABLE public.disbursements RENAME COLUMN requested_at TO created_at; + +-- columns name & id +UPDATE public.disbursements SET name = id; +ALTER TABLE public.disbursements + ALTER COLUMN created_at SET DEFAULT NOW(), + ALTER COLUMN name SET NOT NULL, + ADD CONSTRAINT disbursement_name_unique UNIQUE (name); + +-- column status +UPDATE public.disbursements AS d + SET status = (CASE + WHEN EXISTS(SELECT 1 FROM payments WHERE disbursement_id = d.id AND status != 'SUCCESS') THEN disbursement_status('STARTED') + ELSE disbursement_status('COMPLETED') + END); + +-- column updated_at +CREATE TRIGGER refresh_disbursement_updated_at BEFORE UPDATE ON public.disbursements FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + +-- column status_history +UPDATE public.disbursements SET status_history = ARRAY[create_disbursement_status_history(created_at::TIMESTAMP, disbursement_status('STARTED'), NULL)]; +UPDATE public.disbursements SET status_history = array_prepend(create_disbursement_status_history(NOW(), disbursement_status('COMPLETED'), NULL), status_history) WHERE status = disbursement_status('COMPLETED'); + + +-- +migrate Down +DROP TRIGGER refresh_disbursement_updated_at ON public.disbursements; + +ALTER TABLE public.disbursements + DROP CONSTRAINT disbursement_name_unique, + DROP COLUMN name, + DROP COLUMN status, + DROP COLUMN status_history, + DROP COLUMN updated_at; + +ALTER TABLE public.disbursements RENAME COLUMN created_at TO requested_at; + +DROP FUNCTION create_disbursement_status_history; + +DROP TYPE disbursement_status; \ No newline at end of file diff --git a/internal/db/migrations/2023-01-30.1-update-payments-table.sql b/internal/db/migrations/2023-01-30.1-update-payments-table.sql new file mode 100644 index 000000000..4496858c5 --- /dev/null +++ b/internal/db/migrations/2023-01-30.1-update-payments-table.sql @@ -0,0 +1,86 @@ +-- Update the payments table. + +-- +migrate Up + +CREATE TYPE payment_status AS ENUM( + 'DRAFT', + 'READY', + 'PENDING', + 'PAUSED', + 'SUCCESS', + 'FAILURE' +); + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION create_payment_status_history(time_stamp TIMESTAMP WITH TIME ZONE, pay_status payment_status, status_message VARCHAR) +RETURNS jsonb AS $$ + BEGIN + RETURN json_build_object( + 'timestamp', time_stamp, + 'status', pay_status, + 'status_message', status_message + ); + END; +$$ LANGUAGE plpgsql; +-- +migrate StatementEnd + +ALTER TABLE public.payments RENAME COLUMN requested_at TO created_at; +ALTER TABLE public.payments RENAME COLUMN account_id TO receiver_id; +ALTER TABLE public.payments RENAME COLUMN status TO old_status; +ALTER TABLE public.payments + ALTER COLUMN id SET DEFAULT uuid_generate_v4(), + ALTER COLUMN amount TYPE numeric(19, 7), + ALTER COLUMN status_message TYPE VARCHAR(256), + ALTER COLUMN created_at SET DEFAULT NOW(), + DROP COLUMN custodial_payment_id, + DROP COLUMN idempotency_key, + DROP COLUMN withdrawal_amount, + DROP COLUMN withdrawal_status, + ADD COLUMN stellar_operation_id VARCHAR(32), + ADD COLUMN blockchain_sender_id VARCHAR(69), + ADD COLUMN status payment_status NOT NULL DEFAULT payment_status('DRAFT'), + ADD COLUMN status_history jsonb[] NOT NULL DEFAULT ARRAY[create_payment_status_history(NOW(), payment_status('DRAFT'), NULL)], + ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(); + +-- column status +UPDATE public.payments AS d + SET status = (CASE + WHEN old_status='REQUESTED' THEN payment_status('READY') + WHEN old_status='PENDING' THEN payment_status('PENDING') + WHEN old_status='PENDING_FUNDS' THEN payment_status('FAILURE') + WHEN old_status='SUCCESS' THEN payment_status('SUCCESS') + ELSE payment_status('FAILURE') + END); + +-- column status_history +UPDATE public.payments SET status_history = ARRAY[create_payment_status_history(created_at::TIMESTAMP, payment_status('READY'), NULL)]; +UPDATE public.payments SET status_history = array_prepend(create_payment_status_history(started_at::TIMESTAMP, payment_status('PENDING'), NULL), status_history) WHERE started_at IS NOT NULL; +UPDATE public.payments SET status_history = array_prepend(create_payment_status_history(NOW(), payment_status('FAILURE'), status_message::VARCHAR), status_history) WHERE old_status='PENDING_FUNDS'; +UPDATE public.payments SET status_history = array_prepend(create_payment_status_history(completed_at::TIMESTAMP, payment_status('SUCCESS'), NULL), status_history) WHERE old_status='SUCCESS'; +UPDATE public.payments SET status_history = array_prepend(create_payment_status_history(completed_at::TIMESTAMP, payment_status('FAILURE'), status_message::VARCHAR), status_history) WHERE old_status='FAILURE'; + +-- column updated_at +CREATE TRIGGER refresh_payment_updated_at BEFORE UPDATE ON public.payments FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down +DROP TRIGGER refresh_payment_updated_at ON public.payments; + +ALTER TABLE public.payments + ADD COLUMN custodial_payment_id VARCHAR(36), + ADD COLUMN idempotency_key VARCHAR(64), + ADD COLUMN withdrawal_amount NUMERIC(7,2) NOT NULL DEFAULT 0, + ADD COLUMN withdrawal_status VARCHAR(32), + DROP COLUMN updated_at, + DROP COLUMN status, + DROP COLUMN status_history, + DROP COLUMN stellar_operation_id, + DROP COLUMN blockchain_sender_id; + +ALTER TABLE public.payments RENAME COLUMN old_status TO status; +ALTER TABLE public.payments RENAME COLUMN created_at TO requested_at; +ALTER TABLE public.payments RENAME COLUMN receiver_id TO account_id; + +DROP FUNCTION create_payment_status_history; + +DROP TYPE payment_status; \ No newline at end of file diff --git a/internal/db/migrations/2023-01-30.2-drop-unused-payments-columns.sql b/internal/db/migrations/2023-01-30.2-drop-unused-payments-columns.sql new file mode 100644 index 000000000..060a5db44 --- /dev/null +++ b/internal/db/migrations/2023-01-30.2-drop-unused-payments-columns.sql @@ -0,0 +1,18 @@ +-- Update the payments table. + +-- +migrate Up + +ALTER TABLE public.payments + DROP COLUMN started_at, + DROP COLUMN completed_at, + DROP COLUMN old_status, + DROP COLUMN status_message; + + +-- +migrate Down + +ALTER TABLE public.payments + ADD COLUMN started_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN completed_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN old_status VARCHAR(16), + ADD COLUMN status_message VARCHAR(256); \ No newline at end of file diff --git a/internal/db/migrations/2023-01-30.3-update-receivers-table.sql b/internal/db/migrations/2023-01-30.3-update-receivers-table.sql new file mode 100644 index 000000000..28c704f67 --- /dev/null +++ b/internal/db/migrations/2023-01-30.3-update-receivers-table.sql @@ -0,0 +1,43 @@ +-- Update the receiver table. + +-- +migrate Up + +ALTER TABLE public.receivers RENAME COLUMN registered_at TO created_at; +ALTER TABLE public.receivers + ALTER COLUMN id SET DEFAULT uuid_generate_v4(), + ALTER COLUMN created_at SET DEFAULT NOW(), + DROP COLUMN link_last_sent_at, + DROP COLUMN email_registered_at, + DROP COLUMN public_key_registered_at, + DROP COLUMN hashed_extra_info, + DROP COLUMN hashed_phone_number, + ADD COLUMN encrypted_pii jsonb, + ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(); + +-- COLUMNS encrypted_pii & extra_info +UPDATE public.receivers SET encrypted_pii = json_build_object('date_of_birth', extra_info); +ALTER TABLE public.receivers ALTER COLUMN encrypted_pii SET NOT NULL; +ALTER TABLE public.receivers DROP COLUMN extra_info; + +-- column updated_at +CREATE TRIGGER refresh_receiver_updated_at BEFORE UPDATE ON public.receivers FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down +DROP TRIGGER refresh_receiver_updated_at ON public.receivers; + +ALTER TABLE public.receivers + ADD COLUMN link_last_sent_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN email_registered_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN public_key_registered_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN hashed_extra_info VARCHAR(64), + ADD COLUMN hashed_phone_number VARCHAR(64), + ADD COLUMN extra_info VARCHAR(64), + DROP COLUMN updated_at; + +ALTER TABLE public.receivers RENAME COLUMN created_at TO registered_at; + +-- COLUMNS encrypted_pii & extra_info +UPDATE public.receivers SET extra_info = (encrypted_pii->>'date_of_birth'); +ALTER TABLE public.receivers ALTER COLUMN extra_info SET NOT NULL; +ALTER TABLE public.receivers DROP COLUMN encrypted_pii; diff --git a/internal/db/migrations/2023-01-30.4-receiver-wallets-status.sql b/internal/db/migrations/2023-01-30.4-receiver-wallets-status.sql new file mode 100644 index 000000000..9fed78f62 --- /dev/null +++ b/internal/db/migrations/2023-01-30.4-receiver-wallets-status.sql @@ -0,0 +1,64 @@ +-- This updates the receiver_wallets by adding a status column, and also removes the status column from the receivers table. +-- The status was moved to the receiver_wallets because a receiver can have multiple wallets and would need to properly register each one of them. + +-- +migrate Up + +CREATE TYPE receiver_wallet_status AS ENUM( + 'DRAFT', + 'READY', + 'REGISTERED', + 'FLAGGED' +); + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION create_receiver_wallet_status_history(time_stamp TIMESTAMP WITH TIME ZONE, rw_status receiver_wallet_status) +RETURNS jsonb AS $$ + BEGIN + RETURN json_build_object( + 'timestamp', time_stamp, + 'status', rw_status + ); + END; +$$ LANGUAGE plpgsql; +-- +migrate StatementEnd + +ALTER TABLE public.receiver_wallets + ADD COLUMN status receiver_wallet_status NOT NULL DEFAULT receiver_wallet_status('DRAFT'), + ADD COLUMN status_history jsonb[] NOT NULL DEFAULT ARRAY[create_receiver_wallet_status_history(NOW(), receiver_wallet_status('DRAFT'))]; + +-- COLUMN: status +UPDATE public.receiver_wallets rwOriginal + SET status = ( + CASE WHEN UPPER(r.status) IN ('READY', 'PAID', 'PARTIALLY_CASHED_OUT', 'FULLY_CASHED_OUT') THEN receiver_wallet_status('REGISTERED') + ELSE receiver_wallet_status('READY') + END + ) + FROM public.receiver_wallets rw LEFT JOIN public.receivers r ON rw.receiver_id = r.id + WHERE rwOriginal.id = rw.id; + +-- COLUMN: status_history +UPDATE public.receiver_wallets rwOriginal + SET status_history = ( + CASE WHEN rwOriginal.status = receiver_wallet_status('REGISTERED') THEN ARRAY[create_receiver_wallet_status_history(NOW(), receiver_wallet_status('REGISTERED'))] + ELSE ARRAY[create_receiver_wallet_status_history(NOW(), receiver_wallet_status('READY'))] + END + ) + FROM public.receiver_wallets rw LEFT JOIN public.receivers r ON rw.receiver_id = r.id + WHERE rwOriginal.id = rw.id; + +-- TABLE: receiver +ALTER TABLE public.receivers DROP COLUMN status; + + +-- +migrate Down + +-- TABLE: receiver +ALTER TABLE public.receivers ADD COLUMN status VARCHAR(32); + +ALTER TABLE public.receiver_wallets + DROP COLUMN status, + DROP COLUMN status_history; + +DROP FUNCTION create_receiver_wallet_status_history; + +DROP TYPE receiver_wallet_status; diff --git a/internal/db/migrations/2023-02-03.0-update-messages-add-new-columns.sql b/internal/db/migrations/2023-02-03.0-update-messages-add-new-columns.sql new file mode 100644 index 000000000..bfc0a4bc3 --- /dev/null +++ b/internal/db/migrations/2023-02-03.0-update-messages-add-new-columns.sql @@ -0,0 +1,43 @@ +-- Update the receiver table. + +-- +migrate Up + +CREATE TYPE message_status AS ENUM( + 'PENDING', + 'SUCCESS', + 'FAILURE' +); + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION create_message_status_history(time_stamp TIMESTAMP WITH TIME ZONE, m_status message_status, status_message VARCHAR) +RETURNS jsonb AS $$ + BEGIN + RETURN json_build_object( + 'timestamp', time_stamp, + 'status', m_status, + 'status_message', status_message + ); + END; +$$ LANGUAGE plpgsql; +-- +migrate StatementEnd + +ALTER TABLE public.messages + ADD COLUMN status message_status NOT NULL DEFAULT message_status('PENDING'), + ADD COLUMN status_history jsonb[] NOT NULL DEFAULT ARRAY[create_message_status_history(NOW(), message_status('PENDING'), NULL)], + ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(); + +-- column updated_at +CREATE TRIGGER refresh_message_updated_at BEFORE UPDATE ON public.messages FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down +DROP TRIGGER refresh_message_updated_at ON public.messages; + +ALTER TABLE public.messages + DROP COLUMN status, + DROP COLUMN status_history, + DROP COLUMN updated_at; + +DROP FUNCTION create_message_status_history; + +DROP TYPE message_status; diff --git a/internal/db/migrations/2023-03-09.0-populate-static-data-countries-assets-wallets.sql b/internal/db/migrations/2023-03-09.0-populate-static-data-countries-assets-wallets.sql new file mode 100644 index 000000000..633f965da --- /dev/null +++ b/internal/db/migrations/2023-03-09.0-populate-static-data-countries-assets-wallets.sql @@ -0,0 +1,19 @@ +-- Adds new assets, countries and wallets to the database + +-- +migrate Up + +-- Add USA, BRA and COL to the countries table +INSERT INTO + public.countries (code, name) +VALUES + ('BRA', 'Brazil'), + ('USA', 'United States of America'), + ('COL', 'Colombia'); + +-- +migrate Down + +-- Remove USA, BRA and COL from the countries table +DELETE FROM + public.countries +WHERE + code IN ('BRA', 'USA', 'COL'); diff --git a/internal/db/migrations/2023-03-16.0-create-organization-table.sql b/internal/db/migrations/2023-03-16.0-create-organization-table.sql new file mode 100644 index 000000000..63e1f11cd --- /dev/null +++ b/internal/db/migrations/2023-03-16.0-create-organization-table.sql @@ -0,0 +1,33 @@ +-- This creates the organizations table. + +-- +migrate Up + +-- Table: organizations +CREATE TABLE public.organizations ( + id VARCHAR(36) PRIMARY KEY DEFAULT uuid_generate_v4(), + name VARCHAR(64) NOT NULL, + stellar_main_address VARCHAR(56) NOT NULL, + timezone_utc_offset VARCHAR(6) NOT NULL DEFAULT '+00:00', + are_payments_enabled BOOLEAN NOT NULL DEFAULT FALSE, + sms_registration_message_template VARCHAR(255) NOT NULL DEFAULT 'You have a payment waiting for you from the {{.OrganizationName}}. Click {{.RegistrationLink}} to register.', + + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + + UNIQUE (name), + CONSTRAINT organization_name_not_empty_check CHECK (char_length(name) > 1), + CONSTRAINT organization_timezone_size_check CHECK (char_length(timezone_utc_offset) = 6), + CONSTRAINT organization_sms_registration_message_template_contains_tags_check CHECK (sms_registration_message_template LIKE '%{{.OrganizationName}}%' AND sms_registration_message_template LIKE '%{{.RegistrationLink}}%') +); + +INSERT INTO public.organizations (name, stellar_main_address) VALUES ('MyCustomAid', 'GDA34JZ26FZY64XCSY46CUNSHLX762LHJXQHWWHGL5HSFRWSGBVHUFNI'); + +CREATE TRIGGER refresh_organization_updated_at BEFORE UPDATE ON public.organizations FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down + +-- Table: organizations +DROP TRIGGER refresh_organization_updated_at ON public.organizations; + +DROP TABLE public.organizations CASCADE; diff --git a/internal/db/migrations/2023-03-22.0-enforce-one-row-for-organizations-table.sql b/internal/db/migrations/2023-03-22.0-enforce-one-row-for-organizations-table.sql new file mode 100644 index 000000000..7bd1af41b --- /dev/null +++ b/internal/db/migrations/2023-03-22.0-enforce-one-row-for-organizations-table.sql @@ -0,0 +1,34 @@ +-- Update the organization table to enforce a single row in the whole table. + +-- +migrate Up + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION enforce_single_row_for_organizations() +RETURNS TRIGGER AS $$ +BEGIN + IF (SELECT COUNT(*) FROM public.organizations) != 0 THEN + RAISE EXCEPTION 'public.organizations can must contain exactly one row'; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; +-- +migrate StatementEnd + +CREATE TRIGGER enforce_single_row_for_organizations_insert_trigger + BEFORE INSERT ON public.organizations + FOR EACH ROW + EXECUTE FUNCTION enforce_single_row_for_organizations(); + +CREATE TRIGGER enforce_single_row_for_organizations_delete_trigger + BEFORE DELETE ON public.organizations + FOR EACH ROW + EXECUTE FUNCTION enforce_single_row_for_organizations(); + + +-- +migrate Down + +DROP TRIGGER enforce_single_row_for_organizations_delete_trigger ON public.organizations; + +DROP TRIGGER enforce_single_row_for_organizations_insert_trigger ON public.organizations; + +DROP FUNCTION enforce_single_row_for_organizations; diff --git a/internal/db/migrations/2023-04-12.0-create-submitter-transactions-table.sql b/internal/db/migrations/2023-04-12.0-create-submitter-transactions-table.sql new file mode 100644 index 000000000..9d1bbc80e --- /dev/null +++ b/internal/db/migrations/2023-04-12.0-create-submitter-transactions-table.sql @@ -0,0 +1,36 @@ +-- This creates the submitter_transactions table +-- +migrate Up + +CREATE TYPE transaction_status as enum ('PENDING', 'PROCESSING', 'SENT', 'SUCCESS', 'ERROR'); + +CREATE TABLE public.submitter_transactions ( + id VARCHAR(64) NOT NULL PRIMARY KEY DEFAULT uuid_generate_v4(), + external_id VARCHAR(64), + status transaction_status NOT NULL, + status_message TEXT, + asset_code VARCHAR(12) NOT NULL, + asset_issuer VARCHAR(56) NOT NULL, + amount numeric(10,2) NOT NULL, + destination VARCHAR(64) NOT NULL, + memo VARCHAR(64), + memo_type VARCHAR(12), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + started_at TIMESTAMP WITH TIME ZONE, + sent_at TIMESTAMP WITH TIME ZONE, + completed_at TIMESTAMP WITH TIME ZONE, + stellar_transaction_hash VARCHAR(64), + retry_count INT DEFAULT 0 +); + +CREATE TABLE public.channel_accounts ( + public_key VARCHAR(64) NOT NULL PRIMARY KEY, + private_key VARCHAR(64), + heartbeat TIMESTAMP WITH TIME ZONE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +-- +migrate Down + +DROP TABLE submitter_transactions; +DROP TABLE channel_accounts; +DROP TYPE transaction_status; \ No newline at end of file diff --git a/internal/db/migrations/2023-04-17.0-create-receiver_verifications-table.sql b/internal/db/migrations/2023-04-17.0-create-receiver_verifications-table.sql new file mode 100644 index 000000000..09995d03c --- /dev/null +++ b/internal/db/migrations/2023-04-17.0-create-receiver_verifications-table.sql @@ -0,0 +1,55 @@ +-- This creates the receiver_verifications table that stores the values used to verify a receiver's identity. + +-- +migrate Up +CREATE TYPE verification_type AS ENUM ( + 'DATE_OF_BIRTH', + 'PIN', + 'NATIONAL_ID_NUMBER'); + +CREATE TABLE public.receiver_verifications ( + receiver_id VARCHAR(64) NOT NULL REFERENCES public.receivers (id) ON DELETE CASCADE, + verification_field verification_type NOT NULL, + hashed_value TEXT NOT NULL, + attempts SMALLINT DEFAULT 0 NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT now() NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + confirmed_at TIMESTAMP WITH TIME ZONE, + failed_at TIMESTAMP WITH TIME ZONE, + PRIMARY KEY (receiver_id, verification_field) +); + +-- TRIGGER: updated_at +CREATE TRIGGER refresh_receiver_verifications_updated_at + BEFORE UPDATE ON public.receiver_verifications + FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + +-- Add verification_field to disbursements +ALTER TABLE public.disbursements + ADD COLUMN verification_field verification_type NOT NULL DEFAULT 'DATE_OF_BIRTH'; + +-- Remove PII field from receivers table and add external_id field +ALTER TABLE public.receivers + DROP COLUMN encrypted_pii, + ADD COLUMN external_id VARCHAR(64); + +-- Add receiver_wallet_id to payments table +ALTER TABLE public.payments + ADD COLUMN receiver_wallet_id VARCHAR(64) NOT NULL, + ADD CONSTRAINT fk_payments_receiver_wallet_id FOREIGN KEY (receiver_wallet_id) REFERENCES public.receiver_wallets (id); + +-- +migrate Down +DROP TRIGGER refresh_receiver_verifications_updated_at ON public.receiver_verifications; + +DROP TABLE public.receiver_verifications; + +ALTER TABLE public.disbursements DROP COLUMN verification_field; + +DROP TYPE verification_type; + +ALTER TABLE public.receivers + ADD COLUMN encrypted_pii jsonb, + DROP COLUMN external_id; + +ALTER TABLE public.payments + DROP COLUMN receiver_wallet_id; + diff --git a/internal/db/migrations/2023-04-21.0-add-receiver-wallets-otp.sql b/internal/db/migrations/2023-04-21.0-add-receiver-wallets-otp.sql new file mode 100644 index 000000000..5a36bfdd3 --- /dev/null +++ b/internal/db/migrations/2023-04-21.0-add-receiver-wallets-otp.sql @@ -0,0 +1,17 @@ +-- +migrate Up + +ALTER TABLE public.receiver_wallets + ADD COLUMN otp TEXT NULL; +ALTER TABLE public.receiver_wallets + ADD COLUMN otp_created_at TIMESTAMP WITH TIME ZONE NULL; +ALTER TABLE public.receiver_wallets + ADD COLUMN otp_confirmed_at TIMESTAMP WITH TIME ZONE NULL; + +-- +migrate Down + +ALTER TABLE public.receiver_wallets + DROP COLUMN otp; +ALTER TABLE public.receiver_wallets + DROP COLUMN otp_created_at; +ALTER TABLE public.receiver_wallets + DROP COLUMN otp_confirmed_at; diff --git a/internal/db/migrations/2023-04-25.0.alter-messages-table-add-receiver-wallet-id.sql b/internal/db/migrations/2023-04-25.0.alter-messages-table-add-receiver-wallet-id.sql new file mode 100644 index 000000000..7c113952f --- /dev/null +++ b/internal/db/migrations/2023-04-25.0.alter-messages-table-add-receiver-wallet-id.sql @@ -0,0 +1,24 @@ +-- +migrate Up + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION create_message_status_history(time_stamp TIMESTAMP WITH TIME ZONE, m_status message_status, status_message VARCHAR) +RETURNS jsonb AS $$ + BEGIN + RETURN jsonb_build_object( + 'timestamp', time_stamp, + 'status', m_status, + 'status_message', status_message + ); + END; +$$ LANGUAGE plpgsql; +-- +migrate StatementEnd + +ALTER TABLE public.messages + ADD COLUMN receiver_wallet_id VARCHAR(36) NULL REFERENCES public.receiver_wallets (id), + ALTER COLUMN asset_id DROP NOT NULL; + +-- +migrate Down + +ALTER TABLE public.messages + DROP COLUMN receiver_wallet_id, + ALTER COLUMN asset_id SET NOT NULL; diff --git a/internal/db/migrations/2023-04-26.0-add-demo-wallet.sql b/internal/db/migrations/2023-04-26.0-add-demo-wallet.sql new file mode 100644 index 000000000..806b4cd77 --- /dev/null +++ b/internal/db/migrations/2023-04-26.0-add-demo-wallet.sql @@ -0,0 +1,12 @@ +-- This migration creates the sep_10_client_domain column in the public.wallets table and inserts the demo-wallet in the DB. + +-- +migrate Up + +ALTER TABLE public.wallets ADD COLUMN sep_10_client_domain VARCHAR(255) DEFAULT '' NOT NULL; + +UPDATE public.wallets SET sep_10_client_domain = substring(homepage from 'https?://([^/]+)'); +UPDATE public.wallets SET sep_10_client_domain = 'api-dev.vibrantapp.com' WHERE name = 'Vibrant Assist'; +ALTER TABLE public.wallets ALTER COLUMN deep_link_schema TYPE VARCHAR(255); + +-- +migrate Down +ALTER TABLE public.wallets DROP COLUMN sep_10_client_domain; diff --git a/internal/db/migrations/2023-05-01.0-add-sync-column-tss.sql b/internal/db/migrations/2023-05-01.0-add-sync-column-tss.sql new file mode 100644 index 000000000..930f105fd --- /dev/null +++ b/internal/db/migrations/2023-05-01.0-add-sync-column-tss.sql @@ -0,0 +1,9 @@ +-- This migration adds the `synced_at` column to the `submitter_transactions` table. +-- `synced_at` is used to track whether a transaction has been synced with the SDP. + + +-- +migrate Up +ALTER TABLE public.submitter_transactions ADD COLUMN synced_at TIMESTAMP WITH TIME ZONE NULL; + +-- +migrate Down +ALTER TABLE public.submitter_transactions DROP COLUMN synced_at; diff --git a/internal/db/migrations/2023-05-02.0-alter-organizations-table-add-logo.sql b/internal/db/migrations/2023-05-02.0-alter-organizations-table-add-logo.sql new file mode 100644 index 000000000..66a0bc1e4 --- /dev/null +++ b/internal/db/migrations/2023-05-02.0-alter-organizations-table-add-logo.sql @@ -0,0 +1,9 @@ +-- +migrate Up + +ALTER TABLE public.organizations + ADD COLUMN logo BYTEA NULL; + +-- +migrate Down + +ALTER TABLE public.organizations + DROP COLUMN logo; diff --git a/internal/db/migrations/2023-05-23.0-alter-channel-accounts-pk-type.sql b/internal/db/migrations/2023-05-23.0-alter-channel-accounts-pk-type.sql new file mode 100644 index 000000000..1d7abb4ac --- /dev/null +++ b/internal/db/migrations/2023-05-23.0-alter-channel-accounts-pk-type.sql @@ -0,0 +1,9 @@ +-- +migrate Up + +ALTER TABLE channel_accounts + ALTER COLUMN private_key TYPE VARCHAR(256); + +-- +migrate Down + +ALTER TABLE channel_accounts + ALTER COLUMN private_key TYPE VARCHAR(64); diff --git a/internal/db/migrations/2023-05-31.0-replace-payment-status-enum.sql b/internal/db/migrations/2023-05-31.0-replace-payment-status-enum.sql new file mode 100644 index 000000000..a48eff06f --- /dev/null +++ b/internal/db/migrations/2023-05-31.0-replace-payment-status-enum.sql @@ -0,0 +1,9 @@ +-- This is to update payment_status to change `FAILURE` to `FAILED`. + +-- +migrate Up + +ALTER TYPE payment_status RENAME VALUE 'FAILURE' TO 'FAILED'; + +-- +migrate Down + +ALTER TYPE payment_status RENAME VALUE 'FAILED' TO 'FAILURE'; \ No newline at end of file diff --git a/internal/db/migrations/2023-06-01.0-add-file-fields-to-disbursements.sql b/internal/db/migrations/2023-06-01.0-add-file-fields-to-disbursements.sql new file mode 100644 index 000000000..f7b4a50a8 --- /dev/null +++ b/internal/db/migrations/2023-06-01.0-add-file-fields-to-disbursements.sql @@ -0,0 +1,14 @@ +-- This is to add `file_name` and `file_content` to `disbursements` table. + +-- +migrate Up + +ALTER TABLE disbursements + ADD COLUMN file_content BYTEA NULL, + ADD COLUMN file_name TEXT NULL; + +-- +migrate Down + +ALTER TABLE disbursements + DROP COLUMN file_content, + DROP COLUMN file_name; + diff --git a/internal/db/migrations/2023-06-07.0-add-retry-after-column.sql b/internal/db/migrations/2023-06-07.0-add-retry-after-column.sql new file mode 100644 index 000000000..184670281 --- /dev/null +++ b/internal/db/migrations/2023-06-07.0-add-retry-after-column.sql @@ -0,0 +1,9 @@ +-- This migration adds the `retry_after` column to the `submitter_transactions` table. +-- `retry_after` is used to specify a time in which re-processing of the transaction should not be attempted until after this time. + + +-- +migrate Up +ALTER TABLE public.submitter_transactions ADD COLUMN retry_after TIMESTAMP; + +-- +migrate Down +ALTER TABLE public.submitter_transactions DROP COLUMN retry_after; \ No newline at end of file diff --git a/internal/db/migrations/2023-06-08.0-add-dryrun-message-type.sql b/internal/db/migrations/2023-06-08.0-add-dryrun-message-type.sql new file mode 100644 index 000000000..3db4c191e --- /dev/null +++ b/internal/db/migrations/2023-06-08.0-add-dryrun-message-type.sql @@ -0,0 +1,21 @@ +-- This is to add `DRY_RUN` to the `message_type` enum. + +-- +migrate Up +ALTER TYPE message_type ADD VALUE 'DRY_RUN'; + + +-- +migrate Down +CREATE TYPE temp_message_type AS ENUM ( + 'TWILIO_SMS', + 'AWS_SMS', + 'AWS_EMAIL' +); + +DELETE FROM messages WHERE type = 'DRY_RUN'; + +ALTER TABLE messages + ALTER COLUMN type TYPE temp_message_type USING type::text::temp_message_type; + +DROP TYPE message_type; + +ALTER TYPE temp_message_type RENAME TO message_type; \ No newline at end of file diff --git a/internal/db/migrations/2023-06-22.0-add-unique-constraint-wallet-table.sql b/internal/db/migrations/2023-06-22.0-add-unique-constraint-wallet-table.sql new file mode 100644 index 000000000..a5f03ae03 --- /dev/null +++ b/internal/db/migrations/2023-06-22.0-add-unique-constraint-wallet-table.sql @@ -0,0 +1,7 @@ +-- +migrate Up + +CREATE UNIQUE INDEX unique_wallets_index ON public.wallets(name, homepage, deep_link_schema); + +-- +migrate Down + +DROP INDEX IF EXISTS unique_wallets_index; diff --git a/internal/db/migrations/2023-07-05.0-tss-transactions-table-constraints.sql b/internal/db/migrations/2023-07-05.0-tss-transactions-table-constraints.sql new file mode 100644 index 000000000..855b26361 --- /dev/null +++ b/internal/db/migrations/2023-07-05.0-tss-transactions-table-constraints.sql @@ -0,0 +1,34 @@ +-- +migrate Up + +ALTER TABLE public.submitter_transactions + ADD COLUMN updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ADD COLUMN xdr_sent TEXT UNIQUE, + ADD COLUMN xdr_received TEXT UNIQUE, + ALTER COLUMN external_id SET NOT NULL, + ALTER COLUMN status SET DEFAULT 'PENDING', + ALTER COLUMN amount TYPE numeric(10,7), + ADD CONSTRAINT unique_stellar_transaction_hash UNIQUE (stellar_transaction_hash), + ADD CONSTRAINT check_retry_count CHECK (retry_count >= 0); + +CREATE UNIQUE INDEX idx_unique_external_id ON public.submitter_transactions (external_id) WHERE status != 'ERROR'; + +-- TRIGGER: updated_at +CREATE TRIGGER refresh_submitter_transactions_updated_at BEFORE UPDATE ON public.submitter_transactions FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + + +-- +migrate Down + +-- TRIGGER: updated_at +DROP TRIGGER refresh_submitter_transactions_updated_at ON public.submitter_transactions; + +DROP INDEX idx_unique_external_id; + +ALTER TABLE public.submitter_transactions + DROP COLUMN updated_at, + DROP COLUMN xdr_sent, + DROP COLUMN xdr_received, + ALTER COLUMN external_id DROP NOT NULL, + ALTER COLUMN status DROP DEFAULT, + ALTER COLUMN amount TYPE numeric(10,2), + DROP CONSTRAINT unique_stellar_transaction_hash, + DROP CONSTRAINT check_retry_count; diff --git a/internal/db/migrations/2023-07-17.0-channel-accounts-management-locks.sql b/internal/db/migrations/2023-07-17.0-channel-accounts-management-locks.sql new file mode 100644 index 000000000..6f98fee52 --- /dev/null +++ b/internal/db/migrations/2023-07-17.0-channel-accounts-management-locks.sql @@ -0,0 +1,33 @@ +-- This is to update the channel_accounts table with the locked_until_ledger_number column, for concurrent use. +-- It also deletes the unused heartbeat column and add updated_at and locked_at for improved debuggability. + +-- +migrate Up +ALTER TABLE public.channel_accounts + DROP COLUMN heartbeat, + ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + ADD COLUMN locked_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN locked_until_ledger_number INTEGER; + +-- column updated_at +CREATE TRIGGER refresh_channel_accounts_updated_at BEFORE UPDATE ON public.channel_accounts FOR EACH ROW EXECUTE PROCEDURE update_at_refresh(); + +ALTER TABLE public.submitter_transactions + ADD COLUMN locked_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN locked_until_ledger_number INTEGER, + DROP COLUMN memo, + DROP COLUMN memo_type; + +-- +migrate Down +DROP TRIGGER refresh_channel_accounts_updated_at ON public.channel_accounts; + +ALTER TABLE public.submitter_transactions + DROP COLUMN locked_at, + DROP COLUMN locked_until_ledger_number, + ADD COLUMN memo VARCHAR(64), + ADD COLUMN memo_type VARCHAR(12); + +ALTER TABLE public.channel_accounts + ADD COLUMN heartbeat TIMESTAMP WITH TIME ZONE, + DROP COLUMN updated_at, + DROP COLUMN locked_at, + DROP COLUMN locked_until_ledger_number; diff --git a/internal/db/migrations/2023-07-17.1-tss-remove-SENT-status.sql b/internal/db/migrations/2023-07-17.1-tss-remove-SENT-status.sql new file mode 100644 index 000000000..5be524fdf --- /dev/null +++ b/internal/db/migrations/2023-07-17.1-tss-remove-SENT-status.sql @@ -0,0 +1,56 @@ +-- This migration updates the submitter_transactions table by removing the SENT status. + +-- +migrate Up +-- Create new type +CREATE TYPE transaction_status_new AS ENUM ('PENDING', 'PROCESSING', 'SUCCESS', 'ERROR'); + +-- Add a new column with the new type +ALTER TABLE public.submitter_transactions ADD COLUMN status_new transaction_status_new; + +-- Copy & transform data +UPDATE public.submitter_transactions SET status_new = +CASE + WHEN status = 'SENT' THEN 'PROCESSING'::transaction_status_new + ELSE status::text::transaction_status_new +END; + +-- Drop the old column +ALTER TABLE public.submitter_transactions DROP COLUMN status; + +-- Rename the new column +ALTER TABLE public.submitter_transactions RENAME COLUMN status_new TO status; + +-- Drop old type +DROP TYPE transaction_status; + +-- Rename new type +ALTER TYPE transaction_status_new RENAME TO transaction_status; + +-- Restore index that was when we changed the enum type +CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_external_id ON public.submitter_transactions (external_id) WHERE status != 'ERROR'; + +-- +migrate Down + +-- Create old type +CREATE TYPE transaction_status_old AS ENUM ('PENDING', 'PROCESSING', 'SENT', 'SUCCESS', 'ERROR'); + +-- Add a new column with the old type +ALTER TABLE public.submitter_transactions ADD COLUMN status_old transaction_status_old; + +-- Copy data to new column +UPDATE public.submitter_transactions SET status_old = status::text::transaction_status_old; + +-- Drop the new column +ALTER TABLE public.submitter_transactions DROP COLUMN status; + +-- Rename the old column +ALTER TABLE public.submitter_transactions RENAME COLUMN status_old TO status; + +-- Drop new type +DROP TYPE transaction_status; + +-- Rename old type +ALTER TYPE transaction_status_old RENAME TO transaction_status; + +-- Restore index that was when we changed the enum type +CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_external_id ON public.submitter_transactions (external_id) WHERE status != 'ERROR'; diff --git a/internal/db/migrations/2023-07-17.2-add-status-history-column-submitter-transactions-table.sql b/internal/db/migrations/2023-07-17.2-add-status-history-column-submitter-transactions-table.sql new file mode 100644 index 000000000..75067edde --- /dev/null +++ b/internal/db/migrations/2023-07-17.2-add-status-history-column-submitter-transactions-table.sql @@ -0,0 +1,30 @@ +-- This migration updates the submitter_transactions table by addring the status_history table, for increased debuggability. + +-- +migrate Up + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION create_submitter_transactions_status_history(time_stamp TIMESTAMP WITH TIME ZONE, tss_status transaction_status, status_message VARCHAR, stellar_transaction_hash TEXT, xdr_sent TEXT, xdr_received TEXT) +RETURNS jsonb AS $$ + BEGIN + RETURN json_build_object( + 'timestamp', time_stamp, + 'status', tss_status, + 'status_message', status_message, + 'stellar_transaction_hash', stellar_transaction_hash, + 'xdr_sent', xdr_sent, + 'xdr_received', xdr_received + ); + END; +$$ LANGUAGE plpgsql; +-- +migrate StatementEnd + +ALTER TABLE public.submitter_transactions + ADD COLUMN status_history jsonb[] NULL DEFAULT ARRAY[create_submitter_transactions_status_history(NOW(), 'PENDING', NULL, NULL, NULL, NULL)]; + + +-- +migrate Down + +ALTER TABLE public.submitter_transactions + DROP COLUMN status_history; + +DROP FUNCTION IF EXISTS create_submitter_transactions_status_history; diff --git a/internal/db/migrations/2023-07-20.0-tss-remove-retry_after-and-rename-retry_count.sql b/internal/db/migrations/2023-07-20.0-tss-remove-retry_after-and-rename-retry_count.sql new file mode 100644 index 000000000..03fac8593 --- /dev/null +++ b/internal/db/migrations/2023-07-20.0-tss-remove-retry_after-and-rename-retry_count.sql @@ -0,0 +1,30 @@ +-- This migration removes the unused column retry_after and renames retry_count to attempts_count, updating its value accordingly. +-- Also, updates other columns that were not properly configured. + +-- +migrate Up + +ALTER TABLE public.submitter_transactions DROP COLUMN retry_after; + +ALTER TABLE public.submitter_transactions RENAME COLUMN retry_count TO attempts_count; + +UPDATE public.submitter_transactions SET attempts_count = attempts_count + 1 WHERE status != 'PENDING'; + +--configuring the columns that were not properly configured: +ALTER TABLE public.submitter_transactions + ALTER COLUMN destination TYPE VARCHAR(56), + ALTER COLUMN status SET DEFAULT 'PENDING', + ALTER COLUMN status SET NOT NULL; + +-- +migrate Down + +ALTER TABLE public.submitter_transactions ADD COLUMN retry_after TIMESTAMPTZ; + +ALTER TABLE public.submitter_transactions RENAME COLUMN attempts_count TO retry_count; + +UPDATE public.submitter_transactions SET retry_count = retry_count - 1 WHERE status != 'PENDING' AND retry_count > 0; + +--reverting configuration for the columns that were not properly configured: +ALTER TABLE public.submitter_transactions + ALTER COLUMN destination TYPE VARCHAR(64), + ALTER COLUMN status DROP DEFAULT, + ALTER COLUMN status DROP NOT NULL; diff --git a/internal/db/migrations/2023-08-02.0-organizations-table-add-approver-function.sql b/internal/db/migrations/2023-08-02.0-organizations-table-add-approver-function.sql new file mode 100644 index 000000000..93ae0d640 --- /dev/null +++ b/internal/db/migrations/2023-08-02.0-organizations-table-add-approver-function.sql @@ -0,0 +1,12 @@ +-- +migrate Up + +ALTER TABLE public.organizations + ADD COLUMN is_approval_required boolean NOT NULL DEFAULT false; + +COMMENT ON COLUMN public.organizations.is_approval_required + IS 'Column used to enable disbursement approval for organizations, requiring multiple users to start a disbursement.'; + +-- +migrate Down + +ALTER TABLE public.organizations + DROP COLUMN is_approval_required; diff --git a/internal/db/migrations/2023-08-10.0-countries-seed.sql b/internal/db/migrations/2023-08-10.0-countries-seed.sql new file mode 100644 index 000000000..1b4bed857 --- /dev/null +++ b/internal/db/migrations/2023-08-10.0-countries-seed.sql @@ -0,0 +1,221 @@ +-- +migrate Up + +INSERT INTO public.countries + (code, name) +VALUES + ('AFG', 'Afghanistan'), + ('ALB', 'Albania'), + ('DZA', 'Algeria'), + ('ASM', 'American Samoa'), + ('AND', 'Andorra'), + ('AGO', 'Angola'), + ('ATG', 'Antigua and Barbuda'), + ('ARG', 'Argentina'), + ('ARM', 'Armenia'), + ('ABW', 'Aruba'), + ('AUS', 'Australia'), + ('AUT', 'Austria'), + ('AZE', 'Azerbaijan'), + ('BHS', 'Bahamas'), + ('BHR', 'Bahrain'), + ('BGD', 'Bangladesh'), + ('BRB', 'Barbados'), + ('BLR', 'Belarus'), + ('BEL', 'Belgium'), + ('BLZ', 'Belize'), + ('BEN', 'Benin'), + ('BMU', 'Bermuda'), + ('BTN', 'Bhutan'), + ('BOL', 'Bolivia'), + ('BIH', 'Bosnia and Herzegovina'), + ('BWA', 'Botswana'), + ('BRA', 'Brazil'), + ('BRN', 'Brunei'), + ('BGR', 'Bulgaria'), + ('BFA', 'Burkina Faso'), + ('BDI', 'Burundi'), + ('CPV', 'Cabo Verde'), + ('KHM', 'Cambodia'), + ('CMR', 'Cameroon'), + ('CAN', 'Canada'), + ('CAF', 'Central African Republic'), + ('TCD', 'Chad'), + ('CHL', 'Chile'), + ('CHN', 'China'), + ('COL', 'Colombia'), + ('COM', 'Comoros (the)'), + ('COG', 'Congo (the)'), + ('COK', 'Cook Islands (the)'), + ('CRI', 'Costa Rica'), + ('HRV', 'Croatia'), + ('CYP', 'Cyprus'), + ('CZE', 'Czechia'), + ('CIV', 'CΓ΄te d''Ivoire (Ivory Coast)'), + ('COD', 'Democratic Republic of the Congo'), + ('DNK', 'Denmark'), + ('DJI', 'Djibouti'), + ('DMA', 'Dominica'), + ('DOM', 'Dominican Republic'), + ('ECU', 'Ecuador'), + ('EGY', 'Egypt'), + ('SLV', 'El Salvador'), + ('GNQ', 'Equatorial Guinea'), + ('ERI', 'Eritrea'), + ('EST', 'Estonia'), + ('SWZ', 'Eswatini'), + ('ETH', 'Ethiopia'), + ('FJI', 'Fiji'), + ('FIN', 'Finland'), + ('FRA', 'France'), + ('GUF', 'French Guiana'), + ('PYF', 'French Polynesia'), + ('ATF', 'French Southern Territories (the)'), + ('GAB', 'Gabon'), + ('GMB', 'Gambia (the)'), + ('GEO', 'Georgia'), + ('DEU', 'Germany'), + ('GHA', 'Ghana'), + ('GRC', 'Greece'), + ('GRL', 'Greenland'), + ('GRD', 'Grenada'), + ('GUM', 'Guam'), + ('GTM', 'Guatemala'), + ('GIN', 'Guinea'), + ('GNB', 'Guinea-Bissau'), + ('GUY', 'Guyana'), + ('HTI', 'Haiti'), + ('HND', 'Honduras'), + ('HUN', 'Hungary'), + ('ISL', 'Iceland'), + ('IND', 'India'), + ('IDN', 'Indonesia'), + ('IRQ', 'Iraq'), + ('IRL', 'Ireland'), + ('ISR', 'Israel'), + ('ITA', 'Italy'), + ('JAM', 'Jamaica'), + ('JPN', 'Japan'), + ('JOR', 'Jordan'), + ('KAZ', 'Kazakhstan'), + ('KEN', 'Kenya'), + ('KIR', 'Kiribati'), + ('KOR', 'South Korea'), + ('KWT', 'Kuwait'), + ('KGZ', 'Kyrgyzstan'), + ('LAO', 'Laos'), + ('LVA', 'Latvia'), + ('LBN', 'Lebanon'), + ('LSO', 'Lesotho'), + ('LBR', 'Liberia'), + ('LBY', 'Libya'), + ('LIE', 'Liechtenstein'), + ('LTU', 'Lithuania'), + ('LUX', 'Luxembourg'), + ('MDG', 'Madagascar'), + ('MWI', 'Malawi'), + ('MYS', 'Malaysia'), + ('MDV', 'Maldives'), + ('MLI', 'Mali'), + ('MLT', 'Malta'), + ('MHL', 'Marshall Islands (the)'), + ('MTQ', 'Martinique'), + ('MRT', 'Mauritania'), + ('MUS', 'Mauritius'), + ('MEX', 'Mexico'), + ('FSM', 'Micronesia'), + ('MDA', 'Moldova'), + ('MCO', 'Monaco'), + ('MNG', 'Mongolia'), + ('MNE', 'Montenegro'), + ('MAR', 'Morocco'), + ('MOZ', 'Mozambique'), + ('MMR', 'Myanmar'), + ('NAM', 'Namibia'), + ('NRU', 'Nauru'), + ('NPL', 'Nepal'), + ('NLD', 'Netherlands (the)'), + ('NZL', 'New Zealand'), + ('NIC', 'Nicaragua'), + ('NER', 'Niger'), + ('NGA', 'Nigeria'), + ('MKD', 'North Macedonia (Republic of)'), + ('NOR', 'Norway'), + ('OMN', 'Oman'), + ('PAK', 'Pakistan'), + ('PLW', 'Palau'), + ('PAN', 'Panama'), + ('PNG', 'Papua New Guinea'), + ('PRY', 'Paraguay'), + ('PER', 'Peru'), + ('PHL', 'Philippines (the)'), + ('POL', 'Poland'), + ('PRT', 'Portugal'), + ('PRI', 'Puerto Rico'), + ('QAT', 'Qatar'), + ('ROU', 'Romania'), + ('RUS', 'Russia'), + ('RWA', 'Rwanda'), + ('REU', 'RΓ©union'), + ('BLM', 'Saint Barts'), + ('KNA', 'Saint Kitts and Nevis'), + ('LCA', 'Saint Lucia'), + ('MAF', 'Saint Martin'), + ('VCT', 'Saint Vincent and the Grenadines'), + ('WSM', 'Samoa'), + ('SMR', 'San Marino'), + ('STP', 'Sao Tome and Principe'), + ('SAU', 'Saudi Arabia'), + ('SEN', 'Senegal'), + ('SRB', 'Serbia'), + ('SYC', 'Seychelles'), + ('SLE', 'Sierra Leone'), + ('SGP', 'Singapore'), + ('SVK', 'Slovakia'), + ('SVN', 'Slovenia'), + ('SLB', 'Solomon Islands'), + ('SOM', 'Somalia'), + ('ZAF', 'South Africa'), + ('SSD', 'South Sudan'), + ('ESP', 'Spain'), + ('LKA', 'Sri Lanka'), + ('SDN', 'Sudan (the)'), + ('SUR', 'Suriname'), + ('SWE', 'Sweden'), + ('CHE', 'Switzerland'), + ('TWN', 'Taiwan'), + ('TJK', 'Tajikistan'), + ('TZA', 'Tanzania'), + ('THA', 'Thailand'), + ('TLS', 'Timor-Leste'), + ('TGO', 'Togo'), + ('TON', 'Tonga'), + ('TTO', 'Trinidad and Tobago'), + ('TUN', 'Tunisia'), + ('TUR', 'Turkey'), + ('TKM', 'Turkmenistan'), + ('TCA', 'Turks and Caicos Islands'), + ('TUV', 'Tuvalu'), + ('UGA', 'Uganda'), + ('UKR', 'Ukraine'), + ('ARE', 'United Arab Emirates'), + ('GBR', 'United Kingdom'), + ('UMI', 'United States Minor Outlying Islands'), + ('USA', 'United States of America'), + ('URY', 'Uruguay'), + ('UZB', 'Uzbekistan'), + ('VUT', 'Vanuatu'), + ('VEN', 'Venezuela'), + ('VNM', 'Vietnam'), + ('VGB', 'Virgin Islands (British)'), + ('VIR', 'Virgin Islands (U.S.)'), + ('YEM', 'Yemen'), + ('ZMB', 'Zambia'), + ('ZWE', 'Zimbabwe') +ON CONFLICT DO NOTHING; + +-- +migrate Down + +DELETE FROM + public.countries +WHERE + code NOT IN ('UKR', 'BRA', 'COL', 'USA'); diff --git a/internal/db/migrations/main.go b/internal/db/migrations/main.go new file mode 100644 index 000000000..91cca1c33 --- /dev/null +++ b/internal/db/migrations/main.go @@ -0,0 +1,6 @@ +package migrations + +import "embed" + +//go:embed *.sql +var FS embed.FS diff --git a/internal/db/sql_exec_with_metrics.go b/internal/db/sql_exec_with_metrics.go new file mode 100644 index 000000000..b357d3190 --- /dev/null +++ b/internal/db/sql_exec_with_metrics.go @@ -0,0 +1,148 @@ +package db + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/jmoiron/sqlx" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" +) + +type QueryType string + +const ( + DeleteQueryType QueryType = "DELETE" + InsertQueryType QueryType = "INSERT" + SelectQueryType QueryType = "SELECT" + UndefinedQueryType QueryType = "UNDEFINED" + UpdateQueryType QueryType = "UPDATE" +) + +func NewSQLExecuterWithMetrics(sqlExec SQLExecuter, monitorServiceInterface monitor.MonitorServiceInterface) (*SQLExecuterWithMetrics, error) { + return &SQLExecuterWithMetrics{ + SQLExecuter: sqlExec, + monitorServiceInterface: monitorServiceInterface, + }, nil +} + +// SQLExecuterWithMetrics is a wrapper around SQLExecuter that implements the monitoring service. +type SQLExecuterWithMetrics struct { + SQLExecuter + monitorServiceInterface monitor.MonitorServiceInterface +} + +// make sure SQLExecuterWithMetrics implements SQLExecuter: +var _ SQLExecuter = (*SQLExecuterWithMetrics)(nil) + +// monitorDBQueryDuration is a method that helps monitor the db query duration using the monitoring service. +func (sqlExec *SQLExecuterWithMetrics) monitorDBQueryDuration(duration time.Duration, query string, err error) { + labels := monitor.DBQueryLabels{ + QueryType: string(getQueryType(query)), + } + errMetric := sqlExec.monitorServiceInterface.MonitorDBQueryDuration(duration, getMetricTag(err), labels) + if errMetric != nil { + log.Errorf("Error trying to monitor db query duration: %s", errMetric) + } +} + +// QueryContext is a wrapper around QueryerContext interface QueryContext that includes monitoring the db query. +func (sqlExec *SQLExecuterWithMetrics) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + then := time.Now() + + err := sqlExec.SQLExecuter.GetContext(ctx, dest, query, args...) + + duration := time.Since(then) + + sqlExec.monitorDBQueryDuration(duration, query, err) + + return err +} + +// SelectContext is a wrapper around DBConnetionPool SelectContext that includes monitoring the db query. +func (sqlExec *SQLExecuterWithMetrics) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + then := time.Now() + + err := sqlExec.SQLExecuter.SelectContext(ctx, dest, query, args...) + + duration := time.Since(then) + + sqlExec.monitorDBQueryDuration(duration, query, err) + + return err +} + +// ExecContext is a wrapper around DBConnetionPool ExecContext that includes monitoring the db query. +func (sqlExec *SQLExecuterWithMetrics) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + then := time.Now() + + result, err := sqlExec.SQLExecuter.ExecContext(ctx, query, args...) + + duration := time.Since(then) + + sqlExec.monitorDBQueryDuration(duration, query, err) + + return result, err +} + +// QueryContext is a wrapper around QueryerContext interface QueryContext that includes monitoring the db query. +func (sqlExec *SQLExecuterWithMetrics) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + then := time.Now() + + rows, err := sqlExec.SQLExecuter.QueryContext(ctx, query, args...) + + duration := time.Since(then) + + sqlExec.monitorDBQueryDuration(duration, query, err) + + return rows, err +} + +// QueryxContext is a wrapper around QueryerContext interface QueryxContext that includes monitoring the db query. +func (sqlExec *SQLExecuterWithMetrics) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + then := time.Now() + + rows, err := sqlExec.SQLExecuter.QueryxContext(ctx, query, args...) + + duration := time.Since(then) + + sqlExec.monitorDBQueryDuration(duration, query, err) + + return rows, err +} + +// QueryRowxContext is a wrapper around QueryerContext interface QueryRowxContext that includes monitoring the db query. +func (sqlExec *SQLExecuterWithMetrics) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + then := time.Now() + + row := sqlExec.SQLExecuter.QueryRowxContext(ctx, query, args...) + + duration := time.Since(then) + + sqlExec.monitorDBQueryDuration(duration, query, row.Err()) + + return row +} + +// getMetricTag is a helper that returns the correct metric tag to be used in the monitoring service. +func getMetricTag(err error) monitor.MetricTag { + if err != nil { + return monitor.FailureQueryDurationTag + } + + return monitor.SuccessfulQueryDurationTag +} + +// getQueryType is a helper that return the type of query being executed. +func getQueryType(query string) QueryType { + words := strings.Fields(strings.TrimSpace(query)) + for _, word := range []string{"DELETE", "INSERT", "SELECT", "UPDATE"} { + if word == words[0] { + return QueryType(word) + } + } + // Fresh out of ideas. + return UndefinedQueryType +} diff --git a/internal/db/sql_exec_with_metrics_test.go b/internal/db/sql_exec_with_metrics_test.go new file mode 100644 index 000000000..b2a38603a --- /dev/null +++ b/internal/db/sql_exec_with_metrics_test.go @@ -0,0 +1,438 @@ +package db + +import ( + "context" + "fmt" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestSQLExecWithMetrics_GetContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + sqlExecWithMetrics, err := NewSQLExecuterWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + var mDest string + + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ` + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + t.Run("query successful in GetContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "SELECT", + } + mQuery := "SELECT a.code FROM assets a WHERE a.issuer = 'GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC'" + + mMonitorService.On( + "MonitorDBQueryDuration", + mock.AnythingOfType("time.Duration"), + monitor.SuccessfulQueryDurationTag, + mLabels, + ).Return(nil).Once() + + err := sqlExecWithMetrics.GetContext(ctx, &mDest, mQuery) + require.NoError(t, err) + + expected := "USDC" + assert.Equal(t, expected, mDest) + + mMonitorService.AssertExpectations(t) + }) + + t.Run("query failure in GetContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "SELECT", + } + mQuery := "SELECT a.code FROM assets a WHERE a.issuer = 'invalid_issuer'" + + mMonitorService.On("MonitorDBQueryDuration", mock.AnythingOfType("time.Duration"), monitor.FailureQueryDurationTag, mLabels).Return(nil).Once() + + err := sqlExecWithMetrics.GetContext(ctx, &mDest, mQuery) + require.EqualError(t, err, "sql: no rows in result set") + + mMonitorService.AssertExpectations(t) + }) +} + +func TestSQLExecWithMetrics_SelectContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + sqlExecWithMetrics, err := NewSQLExecuterWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + var mDest []string + + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ` + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + t.Run("query successful in SelectContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "SELECT", + } + mQuery := "SELECT a.code FROM assets a WHERE a.issuer = 'GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC'" + + mMonitorService.On( + "MonitorDBQueryDuration", + mock.AnythingOfType("time.Duration"), + monitor.SuccessfulQueryDurationTag, + mLabels, + ).Return(nil).Once() + + err := sqlExecWithMetrics.SelectContext(ctx, &mDest, mQuery) + require.NoError(t, err) + + expected := []string{"USDC", "EURT"} + assert.Equal(t, expected, mDest) + + mMonitorService.AssertExpectations(t) + }) + + t.Run("query failure in SelectContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "UNDEFINED", + } + mQuery := "invalid query" + + mMonitorService.On("MonitorDBQueryDuration", mock.AnythingOfType("time.Duration"), monitor.FailureQueryDurationTag, mLabels).Return(nil).Once() + + err := sqlExecWithMetrics.SelectContext(ctx, &mDest, mQuery) + require.EqualError(t, err, `pq: syntax error at or near "invalid"`) + + mMonitorService.AssertExpectations(t) + }) +} + +func TestSQLExecWithMetrics_QueryContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + sqlExecWithMetrics, err := NewSQLExecuterWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ` + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + t.Run("query successful in QueryContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "SELECT", + } + mQuery := "SELECT a.code FROM assets a WHERE a.issuer = 'GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC'" + + mMonitorService.On( + "MonitorDBQueryDuration", + mock.AnythingOfType("time.Duration"), + monitor.SuccessfulQueryDurationTag, + mLabels, + ).Return(nil).Once() + + rows, err := sqlExecWithMetrics.QueryContext(ctx, mQuery) + require.NoError(t, err) + defer rows.Close() + + expected := []string{"USDC", "EURT"} + for rows.Next() { + var code string + err := rows.Scan(&code) + require.NoError(t, err) + + assert.Contains(t, expected, code) + } + + mMonitorService.AssertExpectations(t) + }) + + t.Run("query failure in QueryContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "UNDEFINED", + } + mQuery := "invalid query" + + mMonitorService.On("MonitorDBQueryDuration", mock.AnythingOfType("time.Duration"), monitor.FailureQueryDurationTag, mLabels).Return(nil).Once() + + rows, err := sqlExecWithMetrics.QueryContext(ctx, mQuery) + require.EqualError(t, err, `pq: syntax error at or near "invalid"`) + + assert.Nil(t, rows) + + mMonitorService.AssertExpectations(t) + }) +} + +func TestSQLExecWithMetrics_QueryxContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + sqlExecWithMetrics, err := NewSQLExecuterWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ` + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + t.Run("query successful in QueryxContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "SELECT", + } + mQuery := "SELECT a.code FROM assets a WHERE a.issuer = 'GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC'" + + mMonitorService.On( + "MonitorDBQueryDuration", + mock.AnythingOfType("time.Duration"), + monitor.SuccessfulQueryDurationTag, + mLabels, + ).Return(nil).Once() + + rows, err := sqlExecWithMetrics.QueryxContext(ctx, mQuery) + require.NoError(t, err) + defer rows.Close() + + expected := []string{"USDC", "EURT"} + for rows.Next() { + var code string + err := rows.Scan(&code) + require.NoError(t, err) + + assert.Contains(t, expected, code) + } + + mMonitorService.AssertExpectations(t) + }) + + t.Run("query failure in QueryxContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "UNDEFINED", + } + mQuery := "invalid query" + + mMonitorService.On("MonitorDBQueryDuration", mock.AnythingOfType("time.Duration"), monitor.FailureQueryDurationTag, mLabels).Return(nil).Once() + + rows, err := sqlExecWithMetrics.QueryxContext(ctx, mQuery) + require.EqualError(t, err, `pq: syntax error at or near "invalid"`) + + assert.Nil(t, rows) + + mMonitorService.AssertExpectations(t) + }) +} + +func TestSQLExecWithMetrics_QueryRowxContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + sqlExecWithMetrics, err := NewSQLExecuterWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ` + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + t.Run("query successful in QueryRowxContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "SELECT", + } + mQuery := "SELECT a.code FROM assets a WHERE a.issuer = 'GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC'" + + mMonitorService.On( + "MonitorDBQueryDuration", + mock.AnythingOfType("time.Duration"), + monitor.SuccessfulQueryDurationTag, + mLabels, + ).Return(nil).Once() + + var code string + err := sqlExecWithMetrics.QueryRowxContext(ctx, mQuery).Scan(&code) + require.NoError(t, err) + + expected := "USDC" + assert.Contains(t, expected, code) + + mMonitorService.AssertExpectations(t) + }) + + t.Run("query failure in QueryRowxContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "UNDEFINED", + } + mQuery := "invalid query" + + mMonitorService.On("MonitorDBQueryDuration", mock.AnythingOfType("time.Duration"), monitor.FailureQueryDurationTag, mLabels).Return(nil).Once() + + var code string + err := sqlExecWithMetrics.QueryRowxContext(ctx, mQuery).Scan(&code) + require.EqualError(t, err, `pq: syntax error at or near "invalid"`) + + mMonitorService.AssertExpectations(t) + }) +} + +func TestSQLExecWithMetrics_ExecContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + sqlExecWithMetrics, err := NewSQLExecuterWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + const query = ` + INSERT INTO assets + (code, issuer) + VALUES + ($1, $2) + ` + _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + require.NoError(t, err) + + t.Run("query successful in ExecContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "UPDATE", + } + mQuery := "UPDATE assets SET code = $1 WHERE issuer = 'GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC'" + + mMonitorService.On( + "MonitorDBQueryDuration", + mock.AnythingOfType("time.Duration"), + monitor.SuccessfulQueryDurationTag, + mLabels, + ).Return(nil).Once() + + result, err := sqlExecWithMetrics.ExecContext(ctx, mQuery, "EURT") + require.NoError(t, err) + + rowsAffected, err := result.RowsAffected() + require.NoError(t, err) + assert.Equal(t, rowsAffected, int64(1)) + + mMonitorService.AssertExpectations(t) + }) + + t.Run("query failure in ExecContext", func(t *testing.T) { + mLabels := monitor.DBQueryLabels{ + QueryType: "UPDATE", + } + mQuery := "UPDATE invalid_table SET code = $1 WHERE issuer = 'GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC'" + + mMonitorService.On("MonitorDBQueryDuration", mock.AnythingOfType("time.Duration"), monitor.FailureQueryDurationTag, mLabels).Return(nil).Once() + + _, err := sqlExecWithMetrics.ExecContext(ctx, mQuery, "EURT") + require.EqualError(t, err, `pq: relation "invalid_table" does not exist`) + + mMonitorService.AssertExpectations(t) + }) +} + +func TestSQLExecWithMetrics_getMetricTag(t *testing.T) { + t.Run("return successful metric tag", func(t *testing.T) { + metricTag := getMetricTag(nil) + + assert.Equal(t, monitor.SuccessfulQueryDurationTag, metricTag) + }) + + t.Run("return failure metric tag", func(t *testing.T) { + metricTag := getMetricTag(fmt.Errorf("get failed")) + + assert.Equal(t, monitor.FailureQueryDurationTag, metricTag) + }) +} + +func TestSQLExecWithMetrics_getQueryType(t *testing.T) { + testCases := []struct { + query string + expectedQueryType QueryType + }{ + {query: "SELECT * FROM mock_table", expectedQueryType: SelectQueryType}, + {query: "UPDATE mock_table SET mock = 'mock' WHERE id = 1", expectedQueryType: UpdateQueryType}, + {query: "INSERT INTO mock_table (id) VALUES (1)", expectedQueryType: InsertQueryType}, + {query: "DELETE FROM mock_table WHERE id = 1", expectedQueryType: DeleteQueryType}, + {query: "invalid query", expectedQueryType: UndefinedQueryType}, + } + for _, tc := range testCases { + t.Run("get query type for query: "+tc.query, func(t *testing.T) { + queryType := getQueryType(tc.query) + + assert.Equal(t, tc.expectedQueryType, queryType) + }) + } +} diff --git a/internal/dependencyinjection/crash_tracker.go b/internal/dependencyinjection/crash_tracker.go new file mode 100644 index 000000000..32f7508d9 --- /dev/null +++ b/internal/dependencyinjection/crash_tracker.go @@ -0,0 +1,41 @@ +package dependencyinjection + +import ( + "context" + "fmt" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" +) + +const CrashTrackerInstanceName = "crash_tracker_instance" + +// buildCrashTrackerInstanceName sets up a instance name for the crash tracker type +// to either be created and stored, also retrived later, so we can have a instance +// for each type at the same time. +func buildCrashTrackerInstanceName(crashTrackerType crashtracker.CrashTrackerType) string { + return fmt.Sprintf("%s-%s", CrashTrackerInstanceName, string(crashTrackerType)) +} + +// NewCrashTracker creates a new crash tracker instance, or retrives a instance that +// was already created before. +func NewCrashTracker(ctx context.Context, opts crashtracker.CrashTrackerOptions) (crashtracker.CrashTrackerClient, error) { + instanceName := buildCrashTrackerInstanceName(opts.CrashTrackerType) + + // Already initialized + if instance, ok := dependenciesStoreMap[instanceName]; ok { + if crashTrackerInstance, ok := instance.(crashtracker.CrashTrackerClient); ok { + return crashTrackerInstance, nil + } + return nil, fmt.Errorf("error trying to cast crash tracker instance") + } + + // Setup crash tracker instance + newCrashTracker, err := crashtracker.GetClient(ctx, opts) + if err != nil { + return nil, fmt.Errorf("error creating a new crash tracker instance: %w", err) + } + + setInstance(instanceName, newCrashTracker) + + return newCrashTracker, nil +} diff --git a/internal/dependencyinjection/crash_tracker_test.go b/internal/dependencyinjection/crash_tracker_test.go new file mode 100644 index 000000000..ea444f242 --- /dev/null +++ b/internal/dependencyinjection/crash_tracker_test.go @@ -0,0 +1,59 @@ +package dependencyinjection + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_dependencyinjection_buildCrashTrackerInstanceName(t *testing.T) { + testCrashTrackerType := crashtracker.CrashTrackerTypeSentry + result := buildCrashTrackerInstanceName(testCrashTrackerType) + assert.Equal(t, "crash_tracker_instance-SENTRY", result) +} + +func Test_dependencyinjection_NewCrashTracker(t *testing.T) { + ctx := context.Background() + t.Run("should create and return the same instrance on the second call", func(t *testing.T) { + ClearInstancesTestHelper(t) + + testSentryOptions := crashtracker.CrashTrackerOptions{ + CrashTrackerType: crashtracker.CrashTrackerTypeSentry, + } + + gotClient, err := NewCrashTracker(ctx, testSentryOptions) + require.NoError(t, err) + + gotClientDuplicate, err := NewCrashTracker(ctx, testSentryOptions) + require.NoError(t, err) + + assert.Equal(t, &gotClient, &gotClientDuplicate) + }) + + t.Run("should return an error on a invalid option", func(t *testing.T) { + ClearInstancesTestHelper(t) + + testInvalidOptions := crashtracker.CrashTrackerOptions{} + + gotClient, err := NewCrashTracker(ctx, testInvalidOptions) + assert.Nil(t, gotClient) + assert.EqualError(t, err, `error creating a new crash tracker instance: unknown crash tracker type: ""`) + }) + + t.Run("should return an error on a invalid instance", func(t *testing.T) { + ClearInstancesTestHelper(t) + + testSentryOptions := crashtracker.CrashTrackerOptions{ + CrashTrackerType: crashtracker.CrashTrackerTypeSentry, + } + + setInstance(buildCrashTrackerInstanceName(testSentryOptions.CrashTrackerType), false) + + gotClient, err := NewCrashTracker(ctx, testSentryOptions) + assert.Nil(t, gotClient) + assert.EqualError(t, err, "error trying to cast crash tracker instance") + }) +} diff --git a/internal/dependencyinjection/email_client.go b/internal/dependencyinjection/email_client.go new file mode 100644 index 000000000..52dbd6222 --- /dev/null +++ b/internal/dependencyinjection/email_client.go @@ -0,0 +1,53 @@ +package dependencyinjection + +import ( + "fmt" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" +) + +const EmailClientInstanceName = "email_client_instance" + +type EmailClientOptions struct { + EmailType message.MessengerType + MessengerOptions *message.MessengerOptions +} + +// buildEmailClientInstanceName sets up a instance name for the email messenger type +// to either be created and stored, also retrived later, so we can have a instance +// for each type at the same time. +func buildEmailClientInstanceName(emailClientType message.MessengerType) string { + return fmt.Sprintf("%s-%s", EmailClientInstanceName, string(emailClientType)) +} + +// NewEmailClient creates a new email client instance, or retrives a instance that +// was already created before. +func NewEmailClient(opts EmailClientOptions) (message.MessengerClient, error) { + if !opts.EmailType.IsEmail() { + return nil, fmt.Errorf("trying to create a Email client with a non-supported Email type: %q", opts.EmailType) + } + + if opts.MessengerOptions == nil { + opts.MessengerOptions = &message.MessengerOptions{} + } + opts.MessengerOptions.MessengerType = opts.EmailType + + // If there is already an instance of the service, we return the same instance + instanceName := buildEmailClientInstanceName(opts.MessengerOptions.MessengerType) + if instance, ok := dependenciesStoreMap[instanceName]; ok { + if emailClientInstance, ok := instance.(message.MessengerClient); ok { + return emailClientInstance, nil + } + return nil, fmt.Errorf("trying to cast pre-existing Email client for depencency injection") + } + + log.Infof("βš™οΈ Setting Email client to: %v", opts.MessengerOptions.MessengerType) + messengerClient, err := message.GetClient(*opts.MessengerOptions) + if err != nil { + return nil, fmt.Errorf("creating Email client: %w", err) + } + + setInstance(instanceName, messengerClient) + return messengerClient, nil +} diff --git a/internal/dependencyinjection/email_client_test.go b/internal/dependencyinjection/email_client_test.go new file mode 100644 index 000000000..728e45575 --- /dev/null +++ b/internal/dependencyinjection/email_client_test.go @@ -0,0 +1,81 @@ +package dependencyinjection + +import ( + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewEmailClient(t *testing.T) { + t.Run("should return an error on a invalid EMAIL type", func(t *testing.T) { + defer ClearInstancesTestHelper(t) + + mockEmailClientOptions := EmailClientOptions{EmailType: "foo-bar"} + + gotClient, err := NewEmailClient(mockEmailClientOptions) + require.Nil(t, gotClient) + require.Error(t, err) + assert.EqualError(t, err, `trying to create a Email client with a non-supported Email type: "foo-bar"`) + }) + + t.Run("should return the same instance when called twice for the same Email type", func(t *testing.T) { + defer ClearInstancesTestHelper(t) + + // STEP 1: assert that DRY_RUN email client should not be instantiated more than once + mockEmailClientOptions := EmailClientOptions{ + EmailType: message.MessengerTypeDryRun, + } + + dryRunClient1, err := NewEmailClient(mockEmailClientOptions) + require.NoError(t, err) + dryRunClient2, err := NewEmailClient(mockEmailClientOptions) + require.NoError(t, err) + assert.Equal(t, &dryRunClient1, &dryRunClient2) + + // STEP 2: assert that AWS email client should not be instantiated more than once + mockEmailClientOptions = EmailClientOptions{ + EmailType: message.MessengerTypeAWSEmail, + MessengerOptions: &message.MessengerOptions{ + Environment: "dev", + AWSAccessKeyID: "testtesttesttesttesttest", + AWSSecretAccessKey: "testtesttesttesttesttest", + AWSRegion: "testtesttesttesttesttest", + AWSSESSenderID: "test_email@email.com", + }, + } + + awsClient1, err := NewEmailClient(mockEmailClientOptions) + require.NoError(t, err) + awsClient2, err := NewEmailClient(mockEmailClientOptions) + require.NoError(t, err) + assert.Equal(t, &awsClient1, &awsClient2) + + // STEP 3: assert that twilio and aws clients are different + assert.NotEqual(t, &dryRunClient1, &awsClient1) + assert.NotEqual(t, dryRunClient1, awsClient1) + }) + + t.Run("should return an error on a invalid instance", func(t *testing.T) { + defer ClearInstancesTestHelper(t) + + mockTestEmailClientOptions := EmailClientOptions{ + EmailType: message.MessengerTypeAWSEmail, + MessengerOptions: &message.MessengerOptions{ + Environment: "test", + AWSAccessKeyID: "testtesttesttesttesttest", + AWSSecretAccessKey: "testtesttesttesttesttest", + AWSRegion: "testtesttesttesttesttest", + AWSSESSenderID: "test_email@email.com", + }, + } + + preExistingEmailClientWithInvalidType := struct{}{} + setInstance(buildEmailClientInstanceName(message.MessengerTypeAWSEmail), preExistingEmailClientWithInvalidType) + + gotClient, err := NewEmailClient(mockTestEmailClientOptions) + assert.Nil(t, gotClient) + assert.EqualError(t, err, "trying to cast pre-existing Email client for depencency injection") + }) +} diff --git a/internal/dependencyinjection/fixtures.go b/internal/dependencyinjection/fixtures.go new file mode 100644 index 000000000..ab8825a6a --- /dev/null +++ b/internal/dependencyinjection/fixtures.go @@ -0,0 +1,8 @@ +package dependencyinjection + +import "testing" + +func ClearInstancesTestHelper(t *testing.T) { + t.Helper() + dependenciesStoreMap = make(map[string]interface{}) +} diff --git a/internal/dependencyinjection/main.go b/internal/dependencyinjection/main.go new file mode 100644 index 000000000..f3763ebfa --- /dev/null +++ b/internal/dependencyinjection/main.go @@ -0,0 +1,9 @@ +package dependencyinjection + +// dependenciesStoreMap var is the global map for all the service instances. +var dependenciesStoreMap map[string]interface{} = map[string]interface{}{} + +// setInstance adds a new service instance to instances map. +func setInstance(instanceName string, instance interface{}) { + dependenciesStoreMap[instanceName] = instance +} diff --git a/internal/dependencyinjection/sms_client.go b/internal/dependencyinjection/sms_client.go new file mode 100644 index 000000000..4294e0747 --- /dev/null +++ b/internal/dependencyinjection/sms_client.go @@ -0,0 +1,53 @@ +package dependencyinjection + +import ( + "fmt" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" +) + +const SMSClientInstanceName = "sms_client_instance" + +type SMSClientOptions struct { + SMSType message.MessengerType + MessengerOptions *message.MessengerOptions +} + +// buildSMSClientInstanceName sets up a instance name for the SMS messenger type +// to either be created and stored, also retrived later, so we can have a instance +// for each type at the same time. +func buildSMSClientInstanceName(smsClientType message.MessengerType) string { + return fmt.Sprintf("%s-%s", SMSClientInstanceName, string(smsClientType)) +} + +// NewSMSClient creates a new SMS client instance, or retrives a instance that +// was already created before. +func NewSMSClient(opts SMSClientOptions) (message.MessengerClient, error) { + if !opts.SMSType.IsSMS() { + return nil, fmt.Errorf("trying to create a SMS client with a non-supported SMS type: %q", opts.SMSType) + } + + if opts.MessengerOptions == nil { + opts.MessengerOptions = &message.MessengerOptions{} + } + opts.MessengerOptions.MessengerType = opts.SMSType + + // If there is already an instance of the service, we return the same instance + instanceName := buildSMSClientInstanceName(opts.MessengerOptions.MessengerType) + if instance, ok := dependenciesStoreMap[instanceName]; ok { + if smsClientInstance, ok := instance.(message.MessengerClient); ok { + return smsClientInstance, nil + } + return nil, fmt.Errorf("trying to cast pre-existing SMS client for depencency injection") + } + + log.Infof("βš™οΈ Setting SMS client to: %v", opts.MessengerOptions.MessengerType) + messengerClient, err := message.GetClient(*opts.MessengerOptions) + if err != nil { + return nil, fmt.Errorf("creating SMS client: %w", err) + } + + setInstance(instanceName, messengerClient) + return messengerClient, nil +} diff --git a/internal/dependencyinjection/sms_client_test.go b/internal/dependencyinjection/sms_client_test.go new file mode 100644 index 000000000..6dc318b80 --- /dev/null +++ b/internal/dependencyinjection/sms_client_test.go @@ -0,0 +1,85 @@ +package dependencyinjection + +import ( + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewSMSClient(t *testing.T) { + t.Run("should return an error on a invalid SMS type", func(t *testing.T) { + defer ClearInstancesTestHelper(t) + + mockSMSClientOptions := SMSClientOptions{SMSType: "foo-bar"} + + gotClient, err := NewSMSClient(mockSMSClientOptions) + require.Nil(t, gotClient) + require.Error(t, err) + assert.EqualError(t, err, `trying to create a SMS client with a non-supported SMS type: "foo-bar"`) + }) + + t.Run("should return the same instance when called twice for the same SMS type", func(t *testing.T) { + defer ClearInstancesTestHelper(t) + + // STEP 1: assert that Twilio client should not be instantiated more than once + mockSMSClientOptions := SMSClientOptions{ + SMSType: message.MessengerTypeTwilioSMS, + MessengerOptions: &message.MessengerOptions{ + Environment: "dev", + TwilioAccountSID: "testtesttesttesttest", + TwilioAuthToken: "testtesttesttesttest", + TwilioServiceSID: "testtesttesttesttest", + }, + } + + twilioClient1, err := NewSMSClient(mockSMSClientOptions) + require.NoError(t, err) + twilioClient2, err := NewSMSClient(mockSMSClientOptions) + require.NoError(t, err) + assert.Equal(t, &twilioClient1, &twilioClient2) + + // STEP 2: assert that AWS sms client should not be instantiated more than once + mockSMSClientOptions = SMSClientOptions{ + SMSType: message.MessengerTypeAWSSMS, + MessengerOptions: &message.MessengerOptions{ + Environment: "dev", + AWSAccessKeyID: "testtesttesttesttesttest", + AWSSecretAccessKey: "testtesttesttesttesttest", + AWSRegion: "testtesttesttesttesttest", + }, + } + + awsClient1, err := NewSMSClient(mockSMSClientOptions) + require.NoError(t, err) + awsClient2, err := NewSMSClient(mockSMSClientOptions) + require.NoError(t, err) + assert.Equal(t, &awsClient1, &awsClient2) + + // STEP 3: assert that twilio and aws clients are different + assert.NotEqual(t, &twilioClient1, &awsClient1) + assert.NotEqual(t, twilioClient1, awsClient1) + }) + + t.Run("should return an error on a invalid pre-existing instance", func(t *testing.T) { + defer ClearInstancesTestHelper(t) + + mockTestSMSClientOptions := SMSClientOptions{ + SMSType: message.MessengerTypeTwilioSMS, + MessengerOptions: &message.MessengerOptions{ + Environment: "test", + TwilioAccountSID: "testtesttesttesttest", + TwilioAuthToken: "testtesttesttesttest", + TwilioServiceSID: "testtesttesttesttest", + }, + } + + preExistingSMSClientWithInvalidType := struct{}{} + setInstance(buildSMSClientInstanceName(message.MessengerTypeTwilioSMS), preExistingSMSClientWithInvalidType) + + gotClient, err := NewSMSClient(mockTestSMSClientOptions) + assert.Nil(t, gotClient) + assert.EqualError(t, err, "trying to cast pre-existing SMS client for depencency injection") + }) +} diff --git a/internal/htmltemplate/htmltemplate.go b/internal/htmltemplate/htmltemplate.go new file mode 100644 index 000000000..f6589573c --- /dev/null +++ b/internal/htmltemplate/htmltemplate.go @@ -0,0 +1,64 @@ +package htmltemplate + +import ( + "bytes" + "embed" + "fmt" + "text/template" +) + +//go:embed tmpl/*.tmpl +var Tmpl embed.FS + +func ExecuteHTMLTemplate(templateName string, data interface{}) (string, error) { + t, err := template.ParseFS(Tmpl, "tmpl/*.tmpl") + if err != nil { + return "", fmt.Errorf("error parsing embedded template files: %w", err) + } + + var executedTemplate bytes.Buffer + err = t.ExecuteTemplate(&executedTemplate, templateName, data) + if err != nil { + return "", fmt.Errorf("executing html template: %w", err) + } + + return executedTemplate.String(), nil +} + +type EmptyBodyEmailTemplate struct { + Body string +} + +func ExecuteHTMLTemplateForEmailEmptyBody(data EmptyBodyEmailTemplate) (string, error) { + return ExecuteHTMLTemplate("empty_body.tmpl", data) +} + +type InvitationMessageTemplate struct { + FirstName string + Role string + ForgotPasswordLink string + OrganizationName string +} + +func ExecuteHTMLTemplateForInvitationMessage(data InvitationMessageTemplate) (string, error) { + return ExecuteHTMLTemplate("invitation_message.tmpl", data) +} + +type ForgotPasswordMessageTemplate struct { + ResetToken string + ResetPasswordLink string + OrganizationName string +} + +func ExecuteHTMLTemplateForForgotPasswordMessage(data ForgotPasswordMessageTemplate) (string, error) { + return ExecuteHTMLTemplate("forgot_password_message.tmpl", data) +} + +type MFAMessageTemplate struct { + MFACode string + OrganizationName string +} + +func ExecuteHTMLTemplateForMFAMessage(data MFAMessageTemplate) (string, error) { + return ExecuteHTMLTemplate("mfa_message.tmpl", data) +} diff --git a/internal/htmltemplate/htmltemplate_test.go b/internal/htmltemplate/htmltemplate_test.go new file mode 100644 index 000000000..7fe0303ff --- /dev/null +++ b/internal/htmltemplate/htmltemplate_test.go @@ -0,0 +1,82 @@ +package htmltemplate + +import ( + "crypto/rand" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ExecuteHTMLTemplate(t *testing.T) { + // File not found + var inputData interface{} + templateStr, err := ExecuteHTMLTemplate("non-existing-file.html", inputData) + require.Empty(t, templateStr) + require.EqualError(t, err, `executing html template: template: no template "non-existing-file.html" associated with template "empty_body.tmpl"`) + + // handle invalid struct body + inputData = struct { + WrongFieldName string + }{ + WrongFieldName: "foo bar", + } + templateStr, err = ExecuteHTMLTemplate("empty_body.tmpl", inputData) + require.Empty(t, templateStr) + require.EqualError(t, err, `executing html template: template: empty_body.tmpl:9:2: executing "empty_body.tmpl" at <.Body>: can't evaluate field Body in type struct { WrongFieldName string }`) + + // Success πŸŽ‰ + inputData = EmptyBodyEmailTemplate{Body: "foo bar"} + + templateStr, err = ExecuteHTMLTemplate("empty_body.tmpl", inputData) + require.NoError(t, err) + require.Contains(t, templateStr, "\nfoo bar\n") +} + +func Test_ExecuteHTMLTemplateForEmailEmptyBody(t *testing.T) { + // create a random string: + randReader := rand.Reader + b := make([]byte, 10) + _, err := randReader.Read(b) + require.NoError(t, err) + randomStr := fmt.Sprintf("%x", b)[:10] + + // check if the random string is imprinted in the template + inputData := EmptyBodyEmailTemplate{Body: randomStr} + templateStr, err := ExecuteHTMLTemplateForEmailEmptyBody(inputData) + require.NoError(t, err) + require.Contains(t, templateStr, randomStr) +} + +func Test_ExecuteHTMLTemplateForInvitationMessage(t *testing.T) { + forgotPasswordLink := "https://sdp.com/forgot-password" + + data := InvitationMessageTemplate{ + FirstName: "First", + Role: "developer", + ForgotPasswordLink: forgotPasswordLink, + OrganizationName: "Organization Name", + } + content, err := ExecuteHTMLTemplateForInvitationMessage(data) + require.NoError(t, err) + + assert.Contains(t, content, "Hello, First!") + assert.Contains(t, content, "as a developer.") + assert.Contains(t, content, forgotPasswordLink) + assert.Contains(t, content, "Organization Name") +} + +func Test_ExecuteHTMLTemplateForForgotPasswordMessage(t *testing.T) { + data := ForgotPasswordMessageTemplate{ + ResetToken: "resetToken", + ResetPasswordLink: "https://sdp.com/reset-password", + OrganizationName: "Organization Name", + } + content, err := ExecuteHTMLTemplateForForgotPasswordMessage(data) + require.NoError(t, err) + + assert.Contains(t, content, "resetToken") + assert.Contains(t, content, "reset password page") + assert.Contains(t, content, "Organization Name") +} diff --git a/internal/htmltemplate/tmpl/empty_body.tmpl b/internal/htmltemplate/tmpl/empty_body.tmpl new file mode 100644 index 000000000..d442bfe06 --- /dev/null +++ b/internal/htmltemplate/tmpl/empty_body.tmpl @@ -0,0 +1,11 @@ + + + + + + + + +{{.Body}} + + \ No newline at end of file diff --git a/internal/htmltemplate/tmpl/forgot_password_message.tmpl b/internal/htmltemplate/tmpl/forgot_password_message.tmpl new file mode 100644 index 000000000..e27d2a571 --- /dev/null +++ b/internal/htmltemplate/tmpl/forgot_password_message.tmpl @@ -0,0 +1,27 @@ + + + + + Password Reset + + + +

Hello,

+

We received a request to reset your Stellar Disbursement Platform account password. Please use the confirmation token {{.ResetToken}} on the reset password page to create a new password.

+

If you did not request a password reset, please ignore this message or reach out to your organization's administrator with any questions or concerns.

+

Best regards,

+

The {{.OrganizationName}} Team

+ + diff --git a/internal/htmltemplate/tmpl/invitation_message.tmpl b/internal/htmltemplate/tmpl/invitation_message.tmpl new file mode 100644 index 000000000..3d106c831 --- /dev/null +++ b/internal/htmltemplate/tmpl/invitation_message.tmpl @@ -0,0 +1,41 @@ + + + + + Welcome to Stellar Disbursement Platform + + + +

Hello, {{.FirstName}}!

+

You have been added to your organization's Stellar Disbursement Platform as a {{.Role}}. Please click the link below to set up your password and let your organization administrator know if you have any questions.

+

+ Set up my password +

+

Best regards,

+

The {{.OrganizationName}} Team

+ + diff --git a/internal/htmltemplate/tmpl/mfa_message.tmpl b/internal/htmltemplate/tmpl/mfa_message.tmpl new file mode 100644 index 000000000..cc66a1855 --- /dev/null +++ b/internal/htmltemplate/tmpl/mfa_message.tmpl @@ -0,0 +1,24 @@ + + + + + Your verification code + + + +

Here is the 6-digit verification code you requested. Please enter it into the two-factor authentication prompt to sign-in.

+

Your verification code is: {{.MFACode}}.

+

If you did not request this code, please ignore this message and consider changing your password to ensure the security of your account. If you have any questions or concerns, please reach out to your organization's administrator.

+

Best regards,

+

The {{.OrganizationName}} Team

+ + diff --git a/internal/htmltemplate/tmpl/receiver_otp_message.tmpl b/internal/htmltemplate/tmpl/receiver_otp_message.tmpl new file mode 100644 index 000000000..f51974e25 --- /dev/null +++ b/internal/htmltemplate/tmpl/receiver_otp_message.tmpl @@ -0,0 +1 @@ +Here is your wallet verification code: {{.OTP}} diff --git a/internal/htmltemplate/tmpl/receiver_register.tmpl b/internal/htmltemplate/tmpl/receiver_register.tmpl new file mode 100644 index 000000000..2bba739b8 --- /dev/null +++ b/internal/htmltemplate/tmpl/receiver_register.tmpl @@ -0,0 +1,223 @@ + + + + + + Wallet Registration + + + + + + + + + + +
+ +
+
+

Enter your phone number to get verified

+ +

+ Enter your phone number below. If you are pre-approved, you will + receive a one-time passcode.
+ Include + sign and country code in your phone number. Do not enter + spaces or dashes. +

+ +
+
+ + +
+ +
+ +
+ +
+
+
+ + +
+ + +
+
+

Enter passcode

+ +

+ If you are pre-approved, you will receive a one-time passcode, enter + it below to continue. +

+ +

+ + Do not share your OTP or verification data with anyone. People who ask for this information could be trying to access your account. + +

+ +
+
+ + +
+
+ + +
+ +
+ +
+ + +
+
+
+ + + + +
+ + + {{.JWTToken}} +
+ + + + + + diff --git a/internal/htmltemplate/tmpl/receiver_registered_successfully.tmpl b/internal/htmltemplate/tmpl/receiver_registered_successfully.tmpl new file mode 100644 index 000000000..dc5d1eb9d --- /dev/null +++ b/internal/htmltemplate/tmpl/receiver_registered_successfully.tmpl @@ -0,0 +1,49 @@ + + + + + + Wallet Registration Confirmation + + + + + + + + + + +
+ +
+
+

Your information has been successfully verified!

+ +

+ Click the button below to be taken back to home and receive your + disbursement. +

+ +
+ +
+
+
+ + + {{.JWTToken}} +
+ + + + + diff --git a/internal/integrationtests/.env.example b/internal/integrationtests/.env.example new file mode 100644 index 000000000..3ca6d30a5 --- /dev/null +++ b/internal/integrationtests/.env.example @@ -0,0 +1,7 @@ +# Generate a new keypair for SEP-10 signing +SEP10_SIGNING_PUBLIC_KEY= +SEP10_SIGNING_PRIVATE_KEY= + +# Generate a new keypair for the distribution account +DISTRIBUTION_PUBLIC_KEY= +DISTRIBUTION_SEED= \ No newline at end of file diff --git a/internal/integrationtests/anchor_platform.go b/internal/integrationtests/anchor_platform.go new file mode 100644 index 000000000..e6704357e --- /dev/null +++ b/internal/integrationtests/anchor_platform.go @@ -0,0 +1,259 @@ +package integrationtests + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "mime/multipart" + "net/http" + "net/url" + "strings" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" +) + +const ( + authURL = "auth" + sep24DepositURL = "sep24/transactions/deposit/interactive" +) + +type AnchorPlatformIntegrationTestsInterface interface { + StartChallengeTransaction() (*ChallengeTransaction, error) + SignChallengeTransaction(challengeTx *ChallengeTransaction) (*SignedChallengeTransaction, error) + SendSignedChallengeTransaction(signedChallengeTx *SignedChallengeTransaction) (*AnchorPlatformAuthToken, error) + CreateSep24DepositTransaction(authToken *AnchorPlatformAuthToken) (*AnchorPlatformAuthSEP24Token, *AnchorPlatformDepositResponse, error) +} + +type AnchorPlatformIntegrationTests struct { + HttpClient httpclient.HttpClientInterface + AnchorPlatformBaseSepURL string + ReceiverAccountPublicKey string + ReceiverAccountPrivateKey string + Sep10SigningPublicKey string + DisbursedAssetCode string +} + +type ChallengeTransaction struct { + TransactionStr string `json:"transaction"` + NetworkPassphrase string `json:"network_passphrase"` + Transaction *txnbuild.Transaction +} + +type SignedChallengeTransaction struct { + *ChallengeTransaction + SignedTransaction *txnbuild.Transaction +} + +type AnchorPlatformAuthToken struct { + Token string `json:"token"` +} + +type AnchorPlatformDepositResponse struct { + URL string `json:"url"` + TransactionID string `json:"id"` +} + +type AnchorPlatformAuthSEP24Token struct { + Token string `query:"token"` +} + +// StartChallengeTransaction create a new challenge transaction through the anchor platform. +func (ap AnchorPlatformIntegrationTests) StartChallengeTransaction() (*ChallengeTransaction, error) { + authURL, err := url.JoinPath(ap.AnchorPlatformBaseSepURL, authURL) + if err != nil { + return nil, fmt.Errorf("error creating url: %w", err) + } + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authURL, nil) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } + + // hosts domain to be used in txnbuild.ReadChallengeTx + homeDomain := "localhost:8080" + webAuthDomain := "localhost:8080" + + // create query params 'account' and 'home_domain' + q := req.URL.Query() + q.Add("account", ap.ReceiverAccountPublicKey) + q.Add("home_domain", homeDomain) + req.URL.RawQuery = q.Encode() + + resp, err := ap.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request to anchor platform get AUTH: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return nil, fmt.Errorf("error creating challenge transaction on anchor platform") + } + + ct := &ChallengeTransaction{} + err = json.NewDecoder(resp.Body).Decode(ct) + if err != nil { + return nil, fmt.Errorf("error decoding response body: %w", err) + } + + // read the challenge transaction created by the anchor platform and assign it to the ChallengeTransaction object. + tx, _, _, _, err := txnbuild.ReadChallengeTx( + ct.TransactionStr, + ap.Sep10SigningPublicKey, + ct.NetworkPassphrase, + webAuthDomain, + []string{homeDomain}, + ) + if err != nil { + return nil, fmt.Errorf("error reading challenge transaction: %w", err) + } + ct.Transaction = tx + + return ct, nil +} + +// SignChallengeTransaction signs a challenge transaction with the ReceiverAccountPrivateKey. +func (ap AnchorPlatformIntegrationTests) SignChallengeTransaction(challengeTx *ChallengeTransaction) (*SignedChallengeTransaction, error) { + // get the receiver account keypair + kp, err := keypair.ParseFull(ap.ReceiverAccountPrivateKey) + if err != nil { + return nil, fmt.Errorf("error getting receiver keypair: %w", err) + } + + // sign the challenge transaction with the receiver account keypair + st := &SignedChallengeTransaction{ChallengeTransaction: challengeTx} + signedTx, err := challengeTx.Transaction.Sign(challengeTx.NetworkPassphrase, kp) + if err != nil { + return nil, fmt.Errorf("error signing challenge transaction: %w", err) + } + + // attributes signedTx to the SignedChallengeTransaction object + st.SignedTransaction = signedTx + + return st, nil +} + +// SendSignedChallengeTransaction sends the signed transaction to the anchor platform to get the authorization token. +func (ap AnchorPlatformIntegrationTests) SendSignedChallengeTransaction(signedChallengeTx *SignedChallengeTransaction) (*AnchorPlatformAuthToken, error) { + authURL, err := url.JoinPath(ap.AnchorPlatformBaseSepURL, authURL) + if err != nil { + return nil, fmt.Errorf("error creating url: %w", err) + } + + // get the transaction object in base 64 format + txBase64, err := signedChallengeTx.SignedTransaction.Base64() + if err != nil { + return nil, fmt.Errorf("error converting signed transaction to base 64: %w", err) + } + // sets transaction base 64 in request body + data := url.Values{} + data.Set("transaction", txBase64) + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } + + // POST auth endpoint on anchor platform expects the content-type to be x-www-form-urlencoded + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := ap.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request to anchor platform post AUTH: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return nil, fmt.Errorf("error sending signed challenge transaction on anchor platform") + } + + at := &AnchorPlatformAuthToken{} + err = json.NewDecoder(resp.Body).Decode(at) + if err != nil { + return nil, fmt.Errorf("error decoding response body: %w", err) + } + + return at, nil +} + +// CreateSep24DepositTransaction creates a new sep24 deposit transaction on the anchor platform. +// To make this request, an auth token is required and it needs to be obtained through SEP-10. +func (ap AnchorPlatformIntegrationTests) CreateSep24DepositTransaction(authToken *AnchorPlatformAuthToken) (*AnchorPlatformAuthSEP24Token, *AnchorPlatformDepositResponse, error) { + depositUrl, err := url.JoinPath(ap.AnchorPlatformBaseSepURL, sep24DepositURL) + if err != nil { + return nil, nil, fmt.Errorf("error creating url: %w", err) + } + + // creates the multipart/form-data with the necessary fields to complete the request on the anchor platform + b := &bytes.Buffer{} + w := multipart.NewWriter(b) + defer w.Close() + formValues := map[string]string{ + "asset_code": ap.DisbursedAssetCode, + "account": ap.ReceiverAccountPublicKey, + "lang": "en", + "claimable_balance_supported": "false", + } + for k, v := range formValues { + err = w.WriteField(k, v) + if err != nil { + return nil, nil, fmt.Errorf("error writing %q field to form data: %w", k, err) + } + } + // we need to close *multipart.Writter before pass as parameter in http.NewRequestWithContext + w.Close() + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, depositUrl, b) + if err != nil { + return nil, nil, fmt.Errorf("error creating new request: %w", err) + } + + // POST sep24/transactions/deposit/interactive endpoint on anchor platform expects the content-type to be multipart/form-data + req.Header.Set("Content-Type", w.FormDataContentType()) + // sets in the header the authorization token received in SendSignedChallengeTransaction + req.Header.Set("Authorization", "Bearer "+authToken.Token) + + resp, err := ap.HttpClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("error making request to anchor platform post SEP24 Deposit: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return nil, nil, fmt.Errorf("error creating sep24 deposit transaction on anchor platform") + } + + dr := &AnchorPlatformDepositResponse{} + err = json.NewDecoder(resp.Body).Decode(dr) + if err != nil { + return nil, nil, fmt.Errorf("error decoding response body: %w", err) + } + + registerURL, err := url.Parse(dr.URL) + if err != nil { + return nil, nil, fmt.Errorf("error parsing url from AnchorPlatformDepositResponse: %w", err) + } + + queryParams, err := url.ParseQuery(registerURL.RawQuery) + if err != nil { + return nil, nil, fmt.Errorf("error parsing query params from register url: %w", err) + } + + if _, ok := queryParams["token"]; !ok { + return nil, nil, fmt.Errorf("error register url not have a valid token") + } + + at := &AnchorPlatformAuthSEP24Token{ + Token: queryParams.Get("token"), + } + + return at, dr, nil +} + +// Ensuring that AnchorPlatformIntegrationTests is implementing AnchorPlatformIntegrationTestsInterface. +var _ AnchorPlatformIntegrationTestsInterface = (*AnchorPlatformIntegrationTests)(nil) diff --git a/internal/integrationtests/anchor_platform_test.go b/internal/integrationtests/anchor_platform_test.go new file mode 100644 index 000000000..32ee934a6 --- /dev/null +++ b/internal/integrationtests/anchor_platform_test.go @@ -0,0 +1,443 @@ +package integrationtests + +import ( + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_StartChallengeTransaction(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + receiverAccountID := "GDJNLIFC2JTGKTD4LA4D77TSEGMQLZKBIEXMMJT64AEWWVYC5JKJHH2X" + + serverPublicKey := "GD57H5NAK3NFZVR66OGPBAV4FUFUQXQTPOQTSKFLW63SVQVQ4FSQAXMA" + serverPrivateKey := "SBG2NGVW7VYIZDK4R775UXNRZUODJBS3N3H6ICKKAAMXUSWBOHUXETE4" + + ap := AnchorPlatformIntegrationTests{ + HttpClient: &httpClientMock, + AnchorPlatformBaseSepURL: "http://mock_anchor.com/", + ReceiverAccountPublicKey: receiverAccountID, + Sep10SigningPublicKey: serverPublicKey, + } + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + ct, err := ap.StartChallengeTransaction() + require.EqualError(t, err, "error making request to anchor platform get AUTH: error calling the request") + assert.Empty(t, ct) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to create challenge transaction on anchor platform", func(t *testing.T) { + transactionResponse := `{Error creating challenge transaction}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + ct, err := ap.StartChallengeTransaction() + require.EqualError(t, err, "error creating challenge transaction on anchor platform") + assert.Empty(t, ct) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error invalid response body", func(t *testing.T) { + transactionResponse := `` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + ct, err := ap.StartChallengeTransaction() + require.EqualError(t, err, "error decoding response body: EOF") + assert.Empty(t, ct) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error reading challenge transaction on anchor platform", func(t *testing.T) { + invalidServerPrivateKey := "SAB4UJB2NCL5SUJUBDNOVAN6ILOULGDB3G6TTBZ32TERX2N454ORSUIY" + mockCT, err := txnbuild.BuildChallengeTx( + invalidServerPrivateKey, + receiverAccountID, + "localhost:8080", + "localhost:8080", + "Test SDF Network ; September 2015", + time.Second*300, + nil, + ) + require.NoError(t, err) + + transactionStr, err := mockCT.Base64() + require.NoError(t, err) + + transactionResponse := fmt.Sprintf(`{ + "transaction": %q, + "network_passphrase": "Test SDF Network ; September 2015" + }`, transactionStr) + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + ct, err := ap.StartChallengeTransaction() + require.EqualError(t, err, "error reading challenge transaction: transaction source account is not equal to server's account") + assert.Empty(t, ct) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully creating challenge transaction on anchor platform", func(t *testing.T) { + mockCT, err := txnbuild.BuildChallengeTx( + serverPrivateKey, + receiverAccountID, + "localhost:8080", + "localhost:8080", + "Test SDF Network ; September 2015", + time.Second*300, + nil, + ) + require.NoError(t, err) + + transactionStr, err := mockCT.Base64() + require.NoError(t, err) + + transactionResponse := fmt.Sprintf(`{ + "transaction": %q, + "network_passphrase": "Test SDF Network ; September 2015" + }`, transactionStr) + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + ct, err := ap.StartChallengeTransaction() + require.NoError(t, err) + + assert.Equal(t, "Test SDF Network ; September 2015", ct.NetworkPassphrase) + assert.Equal(t, transactionStr, ct.TransactionStr) + assert.Equal(t, mockCT, ct.Transaction) + + httpClientMock.AssertExpectations(t) + }) +} + +func Test_SignChallengeTransaction(t *testing.T) { + receiverPrivateKey := "SCKUFVMBEBE7NCJPJ6DIURH5ECC5ORPYPNG46YFQWCTECEPF35QK4XTO" + receiverAccountID := "GDJNLIFC2JTGKTD4LA4D77TSEGMQLZKBIEXMMJT64AEWWVYC5JKJHH2X" + + serverPrivateKey := "SBG2NGVW7VYIZDK4R775UXNRZUODJBS3N3H6ICKKAAMXUSWBOHUXETE4" + + mockCT, err := txnbuild.BuildChallengeTx( + serverPrivateKey, + receiverAccountID, + "mock_anchor.com", + "mock_anchor.com", + "Test SDF Network ; September 2015", + time.Second*300, + nil, + ) + require.NoError(t, err) + + transactionStr, err := mockCT.Base64() + require.NoError(t, err) + + ct := &ChallengeTransaction{ + TransactionStr: transactionStr, + Transaction: mockCT, + NetworkPassphrase: "Test SDF Network ; September 2015", + } + + t.Run("error getting stellar keypair", func(t *testing.T) { + ap := AnchorPlatformIntegrationTests{ + ReceiverAccountPrivateKey: "invalid private key", + } + st, err := ap.SignChallengeTransaction(ct) + require.EqualError(t, err, "error getting receiver keypair: non-canonical strkey; unused leftover character") + assert.Empty(t, st) + }) + + t.Run("signing challenge transaction", func(t *testing.T) { + ap := AnchorPlatformIntegrationTests{ + ReceiverAccountPrivateKey: receiverPrivateKey, + } + st, err := ap.SignChallengeTransaction(ct) + require.NoError(t, err) + + assert.Equal(t, "Test SDF Network ; September 2015", st.NetworkPassphrase) + assert.Equal(t, transactionStr, st.TransactionStr) + assert.Equal(t, mockCT, st.Transaction) + + kp, err := keypair.ParseFull(receiverPrivateKey) + require.NoError(t, err) + signedTx, err := mockCT.Sign("Test SDF Network ; September 2015", kp) + require.NoError(t, err) + + assert.Equal(t, signedTx, st.SignedTransaction) + }) +} + +func Test_SendSignedChallengeTransaction(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + + ap := AnchorPlatformIntegrationTests{ + HttpClient: &httpClientMock, + AnchorPlatformBaseSepURL: "http://mock_anchor.com/", + } + + receiverPrivateKey := "SCKUFVMBEBE7NCJPJ6DIURH5ECC5ORPYPNG46YFQWCTECEPF35QK4XTO" + receiverAccountID := "GDJNLIFC2JTGKTD4LA4D77TSEGMQLZKBIEXMMJT64AEWWVYC5JKJHH2X" + + serverPrivateKey := "SBG2NGVW7VYIZDK4R775UXNRZUODJBS3N3H6ICKKAAMXUSWBOHUXETE4" + + mockCT, err := txnbuild.BuildChallengeTx( + serverPrivateKey, + receiverAccountID, + "mock_anchor.com", + "mock_anchor.com", + "Test SDF Network ; September 2015", + time.Second*300, + nil, + ) + require.NoError(t, err) + + transactionStr, err := mockCT.Base64() + require.NoError(t, err) + + kp, err := keypair.ParseFull(receiverPrivateKey) + require.NoError(t, err) + signedTx, err := mockCT.Sign("Test SDF Network ; September 2015", kp) + require.NoError(t, err) + + st := &SignedChallengeTransaction{ + ChallengeTransaction: &ChallengeTransaction{ + TransactionStr: transactionStr, + Transaction: mockCT, + NetworkPassphrase: "Test SDF Network ; September 2015", + }, + SignedTransaction: signedTx, + } + + t.Run("error converting signed transaction to base 64", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + + at, err := ap.SendSignedChallengeTransaction(st) + require.EqualError(t, err, "error making request to anchor platform post AUTH: error calling the request") + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to send signed challenge transaction on anchor platform", func(t *testing.T) { + transactionResponse := `{Error sending signed challenge transaction}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, err := ap.SendSignedChallengeTransaction(st) + require.EqualError(t, err, "error sending signed challenge transaction on anchor platform") + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error invalid response body", func(t *testing.T) { + transactionResponse := `` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, err := ap.SendSignedChallengeTransaction(st) + require.EqualError(t, err, "error decoding response body: EOF") + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully sending signed challenge transaction on anchor platform", func(t *testing.T) { + authToken := "valid token" + + transactionResponse := fmt.Sprintf(`{ + "token": %q + }`, authToken) + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(transactionResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, err := ap.SendSignedChallengeTransaction(st) + require.NoError(t, err) + + assert.Equal(t, authToken, at.Token) + + httpClientMock.AssertExpectations(t) + }) +} + +func Test_CreateSep24DepositTransaction(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + + receiverAccountID := "GDJNLIFC2JTGKTD4LA4D77TSEGMQLZKBIEXMMJT64AEWWVYC5JKJHH2X" + + ap := AnchorPlatformIntegrationTests{ + HttpClient: &httpClientMock, + AnchorPlatformBaseSepURL: "http://mock_anchor.com/", + ReceiverAccountPublicKey: receiverAccountID, + DisbursedAssetCode: "USDC", + } + + at := &AnchorPlatformAuthToken{ + Token: "valid token", + } + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + + at, dr, err := ap.CreateSep24DepositTransaction(at) + require.EqualError(t, err, "error making request to anchor platform post SEP24 Deposit: error calling the request") + assert.Empty(t, dr) + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to create sep24 deposit transaction on anchor platform", func(t *testing.T) { + depositResponse := `{Error creating sep24 deposit transaction}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(depositResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, dr, err := ap.CreateSep24DepositTransaction(at) + require.EqualError(t, err, "error creating sep24 deposit transaction on anchor platform") + assert.Empty(t, dr) + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error invalid response body", func(t *testing.T) { + depositResponse := `` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(depositResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, dr, err := ap.CreateSep24DepositTransaction(at) + require.EqualError(t, err, "error decoding response body: EOF") + assert.Empty(t, dr) + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error invalid url in response body", func(t *testing.T) { + depositResponse := `{ + "id": "mock_id", + "url": "%" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(depositResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, dr, err := ap.CreateSep24DepositTransaction(at) + require.EqualError(t, err, "error parsing url from AnchorPlatformDepositResponse: parse \"%\": invalid URL escape \"%\"") + assert.Empty(t, dr) + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error invalid query params in url from response body", func(t *testing.T) { + depositResponse := `{ + "id": "mock_id", + "url": "http://mock_registration_url.com?q=%" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(depositResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, dr, err := ap.CreateSep24DepositTransaction(at) + require.EqualError(t, err, "error parsing query params from register url: invalid URL escape \"%\"") + assert.Empty(t, dr) + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error url from response body missing token", func(t *testing.T) { + depositResponse := `{ + "id": "mock_id", + "url": "http://mock_registration_url.com" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(depositResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, dr, err := ap.CreateSep24DepositTransaction(at) + require.EqualError(t, err, "error register url not have a valid token") + assert.Empty(t, dr) + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully creating sep24 deposit transaction on anchor platform", func(t *testing.T) { + depositResponse := `{ + "id": "mock_id", + "url": "http://mock_registration_url.com?token=valid_token" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(depositResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, dr, err := ap.CreateSep24DepositTransaction(at) + require.NoError(t, err) + + assert.Equal(t, "mock_id", dr.TransactionID) + assert.Equal(t, "http://mock_registration_url.com?token=valid_token", dr.URL) + assert.Equal(t, "valid_token", at.Token) + + httpClientMock.AssertExpectations(t) + }) +} diff --git a/internal/integrationtests/docker-compose-e2e-tests.yml b/internal/integrationtests/docker-compose-e2e-tests.yml new file mode 100644 index 000000000..735a135ed --- /dev/null +++ b/internal/integrationtests/docker-compose-e2e-tests.yml @@ -0,0 +1,194 @@ +version: '3.8' +services: + db: + container_name: e2e-sdp-v2-database + image: postgres:14-alpine + environment: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: e2e-sdp + PGDATA: /data/postgres + ports: + - "5432:5432" + volumes: + - e2e-postgres-db:/data/postgres + + sdp-api: + container_name: e2e-sdp-api + image: stellar/sdp-v2:latest + build: + context: ../../ + dockerfile: Dockerfile + ports: + - "8000:8000" + environment: + BASE_URL: http://localhost:8000 + DATABASE_URL: postgres://postgres@db:5432/e2e-sdp?sslmode=disable + ENVIRONMENT: localhost + LOG_LEVEL: TRACE + PORT: "8000" + METRICS_PORT: "8002" + METRICS_TYPE: PROMETHEUS + EMAIL_SENDER_TYPE: DRY_RUN + SMS_SENDER_TYPE: DRY_RUN + NETWORK_PASSPHRASE: "Test SDF Network ; September 2015" + EC256_PUBLIC_KEY: "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEJ3HNphPAEKHvtRjsl5Kjwc9tTMqS\n2pmYNybrLsxZ6cuQvg2yiEoXZixP2cJ77csHClXC6cb1wQp/BNGDvGKoPg==\n-----END PUBLIC KEY-----" + SEP10_SIGNING_PUBLIC_KEY: ${SEP10_SIGNING_PUBLIC_KEY} + ANCHOR_PLATFORM_BASE_SEP_URL: http://anchor-platform:8080 + ANCHOR_PLATFORM_BASE_PLATFORM_URL: http://anchor-platform:8085 + DISTRIBUTION_PUBLIC_KEY: ${DISTRIBUTION_PUBLIC_KEY} + RECAPTCHA_SITE_KEY: 6LeIxAcTAAAAAJcZVRqyHh71UMIEGNQ_MXjiZKhI + CORS_ALLOWED_ORIGINS: "*" + ENABLE_MFA: "false" + ENABLE_RECAPTCHA: "false" + + # integration tests vars + USER_EMAIL: ${USER_EMAIL} + USER_PASSWORD: ${USER_PASSWORD} + DISBURSED_ASSET_CODE: USDC + DISBURSED_ASSET_ISSUER: GDKLFXO3FL25I7ST632KMMBP5D72QGTDV55TOWUB2XG2O67NNQDKYMLG + RECEIVER_ACCOUNT_PUBLIC_KEY: GCDYFAJSZPH3RCXL6NWMMOY54CXNUBYFTDCBW7GGG6VPBW3WSDKSB2NU + RECEIVER_ACCOUNT_PRIVATE_KEY: SDSAVUWVNOFG2JEHKIWEUHAYIA6PLGEHLMHX2TMVKEQGZKOFQ7XXKDFE + DISBURSEMENT_CSV_FILE_PATH: files + DISBURSEMENT_CSV_FILE_NAME: disbursement_integration_tests.csv + SERVER_API_BASE_URL: http://localhost:8000 + + # secrets: + AWS_ACCESS_KEY_ID: MY_AWS_ACCESS_KEY_ID + AWS_REGION: MY_AWS_REGION + AWS_SECRET_ACCESS_KEY: MY_AWS_SECRET_ACCESS_KEY + AWS_SES_SENDER_ID: MY_AWS_SES_SENDER_ID + TWILIO_ACCOUNT_SID: MY_TWILIO_ACCOUNT_SID + TWILIO_AUTH_TOKEN: MY_TWILIO_AUTH_TOKEN + TWILIO_SERVICE_SID: MY_TWILIO_SERVICE_SID + EC256_PRIVATE_KEY: "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgdo6o+tdFkF94B7z8\nnoybH6/zO3PryLLjLbj54/zOi4WhRANCAAQncc2mE8AQoe+1GOyXkqPBz21MypLa\nmZg3JusuzFnpy5C+DbKIShdmLE/ZwnvtywcKVcLpxvXBCn8E0YO8Yqg+\n-----END PRIVATE KEY-----" + SEP10_SIGNING_PRIVATE_KEY: ${SEP10_SIGNING_PRIVATE_KEY} + SEP24_JWT_SECRET: jwt_secret_1234567890 + RECAPTCHA_SITE_SECRET_KEY: 6LeIxAcTAAAAAGG-vFI1TnRWxMZNFuojJ4WifJWe + ANCHOR_PLATFORM_OUTGOING_JWT_SECRET: mySdpToAnchorPlatformSecret + DISTRIBUTION_SEED: ${DISTRIBUTION_SEED} + entrypoint: "" + command: + - sh + - -c + - | + sleep 5 + ./stellar-disbursement-platform db migrate up + ./stellar-disbursement-platform db auth migrate up + ./stellar-disbursement-platform db setup-for-network + ./stellar-disbursement-platform serve + depends_on: + - db + + tss: + container_name: e2e-sdp-tss + image: stellar/sdp-v2:latest + build: + context: ../../ + dockerfile: Dockerfile + ports: + - "9000:9000" + environment: + DATABASE_URL: postgres://postgres@db:5432/e2e-sdp?sslmode=disable + NETWORK_PASSPHRASE: "Test SDF Network ; September 2015" + HORIZON_URL: "https://horizon-testnet.stellar.org" + NUM_CHANNEL_ACCOUNTS: "1" + MAX_BASE_FEE: "100" + MOCK: "false" + TSS_METRICS_PORT: "9002" + TSS_METRICS_TYPE: "TSS_PROMETHEUS" + DISTRIBUTION_SEED: ${DISTRIBUTION_SEED} + depends_on: + - db + - sdp-api + entrypoint: "" + command: + - sh + - -c + - | + sleep 30 + ./stellar-disbursement-platform channel-accounts verify --delete-invalid-accounts && + ./stellar-disbursement-platform channel-accounts ensure --num-channel-accounts-ensure 1 + ./stellar-disbursement-platform tss + + db-anchor-platform: + container_name: e2e-anchor-platform-postgres-db + image: postgres:14-alpine + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: e2e-postgres + PGPORT: 5433 + ports: + - "5433:5433" + volumes: + - e2e-postgres-ap-db:/data/postgres + + anchor-platform: + container_name: e2e-anchor-platform + image: stellar/anchor-platform:2.1.3 + command: --sep-server --platform-server --platform linux/amd64 + ports: + - "8080:8080" # sep-server + - "8085:8085" # platform-server + - "8082:8082" # metrics + depends_on: + - db-anchor-platform + environment: + HOST_URL: http://localhost:8080 + SEP_SERVER_PORT: 8080 + CALLBACK_API_BASE_URL: http://sdp-api:8000 + CALLBACK_API_AUTH_TYPE: none # TODO: update to jwt later + PLATFORM_SERVER_AUTH_TYPE: JWT + APP_LOGGING_LEVEL: INFO + DATA_TYPE: postgres + DATA_SERVER: db-anchor-platform:5433 + DATA_DATABASE: e2e-postgres + DATA_FLYWAY_ENABLED: "true" + DATA_DDL_AUTO: update + METRICS_ENABLED: "false" # Metrics would be available at port 8082 + METRICS_EXTRAS_ENABLED: "false" + SEP10_ENABLED: "true" + SEP10_HOME_DOMAIN: localhost:8080 + SEP24_ENABLED: "true" + SEP24_INTERACTIVE_URL_BASE_URL: http://sdp-api:8000/wallet-registration/start + SEP24_INTERACTIVE_URL_JWT_EXPIRATION: 1800 # 1800 seconds is 30 minutes + SEP24_MORE_INFO_URL_BASE_URL: http://sdp-api:8000/wallet-registration/start + SEP1_ENABLED: "true" + SEP1_TOML_TYPE: url + SEP1_TOML_VALUE: http://sdp-api:8000/.well-known/stellar.toml + ASSETS_TYPE: json + ASSETS_VALUE: | + { + "assets": [ + { + "sep24_enabled": true, + "schema": "stellar", + "code": "USDC", + "issuer": "GDKLFXO3FL25I7ST632KMMBP5D72QGTDV55TOWUB2XG2O67NNQDKYMLG", + "distribution_account": "${DISTRIBUTION_PUBLIC_KEY}", + "significant_decimals": 7, + "deposit": { + "enabled": true, + "fee_minimum": 0, + "fee_percent": 0, + "min_amount": 1, + "max_amount": 10000 + }, + "withdraw": {"enabled": false} + } + ] + } + + # secrets: + SECRET_DATA_USERNAME: postgres + SECRET_DATA_PASSWORD: postgres + SECRET_PLATFORM_API_AUTH_SECRET: mySdpToAnchorPlatformSecret + SECRET_SEP10_JWT_SECRET: jwt_secret_1234567890 + SECRET_SEP10_SIGNING_SEED: ${SEP10_SIGNING_PRIVATE_KEY} + SECRET_SEP24_INTERACTIVE_URL_JWT_SECRET: jwt_secret_1234567890 + SECRET_SEP24_MORE_INFO_URL_JWT_SECRET: jwt_secret_1234567890 +volumes: + e2e-postgres-db: + driver: local + e2e-postgres-ap-db: + driver: local \ No newline at end of file diff --git a/internal/integrationtests/e2e_integration_test.sh b/internal/integrationtests/e2e_integration_test.sh new file mode 100755 index 000000000..2834f3511 --- /dev/null +++ b/internal/integrationtests/e2e_integration_test.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# This script is used to run e2e integration tests locally with all necessary steps. +set -eu + +export DIVIDER="----------------------------------------" + +# prepare +echo "====> πŸ‘€Step 1: start preparation" +docker container ps -aq -f name='e2e' --format '{{.ID}}' | xargs docker stop | xargs docker rm -v && +docker volume ls -f name='e2e' --format '{{.Name}}' | xargs docker volume rm +echo "====> βœ…Step 1: finish preparation" + +# Run docker compose +echo $DIVIDER +echo "====> πŸ‘€Step 2: build sdp-api, anchor-platform and tss" +docker-compose -f docker-compose-e2e-tests.yml up --build -d +sleep 10 +echo "====> βœ…Step 2: finishing build" + +# Create new auth user +echo $DIVIDER +echo "====> πŸ‘€Step 3: create a new auth user on SDP API" +docker exec e2e-sdp-api bash -c "echo '$USER_PASSWORD' | ./stellar-disbursement-platform auth add-user '$USER_EMAIL' joe yabuki --password --owner --roles owner" +echo "====> βœ…Step 3: finish creating new auth user" + +# Create integration test data +echo $DIVIDER +echo "====> πŸ‘€Step 4: create new asset and test wallet on database" +docker exec e2e-sdp-api bash -c "./stellar-disbursement-platform integration-tests create-data" +echo "====> βœ…Step 4: finish creating integration test data" + +# Restart anchor platform container +echo $DIVIDER +echo "====> πŸ‘€Step 5: restart anchor platform container to get the new created asset" +docker restart e2e-anchor-platform +echo "waiting for anchor platform to initialize" +sleep 120 +echo "====> βœ…Step 5: finish restarting anchor platform container" + +# Run integration tests +echo $DIVIDER +echo "====> πŸ‘€Step 6: run integration tests command" +docker exec e2e-sdp-api bash -c "./stellar-disbursement-platform integration-tests start" +echo "====> βœ…Step 6: finish running integration test data" + +# Cleanup container and volumes +echo $DIVIDER +echo "====> πŸ‘€Step 7: cleaning up e2e containers and volumes" +docker container ps -aq -f name='e2e' --format '{{.ID}}' | xargs docker stop | xargs docker rm -v && +docker volume ls -f name='e2e' --format '{{.Name}}' | xargs docker volume rm +echo "====> βœ…Step 7: finish cleaning up containers and volumes" + +echo $DIVIDER +echo "πŸŽ‰πŸŽ‰πŸŽ‰πŸŽ‰ SUCCESS! πŸŽ‰πŸŽ‰πŸŽ‰πŸŽ‰" diff --git a/internal/integrationtests/files/disbursement_integration_tests.csv b/internal/integrationtests/files/disbursement_integration_tests.csv new file mode 100644 index 000000000..516c157e8 --- /dev/null +++ b/internal/integrationtests/files/disbursement_integration_tests.csv @@ -0,0 +1,2 @@ +phone,id,amount,verification ++12025550191,1,0.1,1999-03-30 \ No newline at end of file diff --git a/internal/integrationtests/files/empty_csv_file.csv b/internal/integrationtests/files/empty_csv_file.csv new file mode 100644 index 000000000..e69de29bb diff --git a/internal/integrationtests/integration_tests.go b/internal/integrationtests/integration_tests.go new file mode 100644 index 000000000..477b537e6 --- /dev/null +++ b/internal/integrationtests/integration_tests.go @@ -0,0 +1,268 @@ +package integrationtests + +import ( + "context" + "fmt" + "time" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httphandler" + tss "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" +) + +const paymentProcessTimeMinutes = 3 + +type IntegrationTestsInterface interface { + StartIntegrationTests(ctx context.Context, opts IntegrationTestsOpts) error + CreateTestData(ctx context.Context, opts IntegrationTestsOpts) error +} + +type IntegrationTestsOpts struct { + DatabaseDSN string + UserEmail string + UserPassword string + DisbursedAssetCode string + DisbursetAssetIssuer string + WalletName string + WalletHomepage string + WalletDeepLink string + WalletSEP10Domain string + DisbursementName string + DisbursementCSVFilePath string + DisbursementCSVFileName string + ReceiverAccountPublicKey string + ReceiverAccountPrivateKey string + ReceiverAccountStellarMemo string + Sep10SigningPublicKey string + RecaptchaSiteKey string + AnchorPlatformBaseSepURL string + ServerApiBaseURL string +} + +type IntegrationTestsService struct { + models *data.Models + dbConnectionPool db.DBConnectionPool + serverAPI ServerApiIntegrationTestsInterface + anchorPlatform AnchorPlatformIntegrationTestsInterface + horizonClient horizonclient.ClientInterface +} + +// NewIntegrationTestsService is a function that create a new IntegrationTestsService instance. +func NewIntegrationTestsService(opts IntegrationTestsOpts) (*IntegrationTestsService, error) { + // initialize dbConnection and data.Models + dbConnectionPool, err := db.OpenDBConnectionPool(opts.DatabaseDSN) + if err != nil { + return nil, fmt.Errorf("error connecting to the database: %w", err) + } + + models, err := data.NewModels(dbConnectionPool) + if err != nil { + return nil, fmt.Errorf("error creating models for integration tests: %w", err) + } + + return &IntegrationTestsService{ + models: models, + dbConnectionPool: dbConnectionPool, + }, nil +} + +func (it *IntegrationTestsService) initServices(ctx context.Context, opts IntegrationTestsOpts) { + // initialize default testnet horizon client + it.horizonClient = horizonclient.DefaultTestNetClient + + // initialize anchor platform integration tests service + it.anchorPlatform = &AnchorPlatformIntegrationTests{ + HttpClient: httpclient.DefaultClient(), + AnchorPlatformBaseSepURL: opts.AnchorPlatformBaseSepURL, + ReceiverAccountPublicKey: opts.ReceiverAccountPublicKey, + ReceiverAccountPrivateKey: opts.ReceiverAccountPrivateKey, + Sep10SigningPublicKey: opts.Sep10SigningPublicKey, + DisbursedAssetCode: opts.DisbursedAssetCode, + } + + // initialize server api integration tests service + it.serverAPI = &ServerApiIntegrationTests{ + HttpClient: httpclient.DefaultClient(), + ServerApiBaseURL: opts.ServerApiBaseURL, + UserEmail: opts.UserEmail, + UserPassword: opts.UserPassword, + DisbursementCSVFilePath: opts.DisbursementCSVFilePath, + DisbursementCSVFileName: opts.DisbursementCSVFileName, + } +} + +func (it *IntegrationTestsService) StartIntegrationTests(ctx context.Context, opts IntegrationTestsOpts) error { + log.Ctx(ctx).Info("Starting integration tests ......") + + it.initServices(ctx, opts) + log.Ctx(ctx).Info("Login user to get server API auth token") + authToken, err := it.serverAPI.Login(ctx) + if err != nil { + return fmt.Errorf("error trying to login in server API: %w", err) + } + log.Ctx(ctx).Info("User logged in") + log.Ctx(ctx).Info(authToken) + + log.Ctx(ctx).Info("Getting test asset in database") + asset, err := it.models.Assets.GetByCodeAndIssuer(ctx, opts.DisbursedAssetCode, opts.DisbursetAssetIssuer) + if err != nil { + return fmt.Errorf("error getting test asset: %w", err) + } + + log.Ctx(ctx).Info("Getting test wallet in database") + wallet, err := it.models.Wallets.GetByWalletName(ctx, opts.WalletName) + if err != nil { + return fmt.Errorf("error getting test wallet: %w", err) + } + + log.Ctx(ctx).Info("Creating disbursement using server API") + disbursement, err := it.serverAPI.CreateDisbursement(ctx, authToken, &httphandler.PostDisbursementRequest{ + Name: opts.DisbursementName, + CountryCode: "USA", + WalletID: wallet.ID, + AssetID: asset.ID, + }) + if err != nil { + return fmt.Errorf("error creating disbursement: %w", err) + } + log.Ctx(ctx).Info("Disbursement created") + + log.Ctx(ctx).Info("Processing disbursement CSV file using server API") + err = it.serverAPI.ProcessDisbursement(ctx, authToken, disbursement.ID) + if err != nil { + return fmt.Errorf("error processing disbursement: %w", err) + } + log.Ctx(ctx).Info("CSV disbursement file processed") + + log.Ctx(ctx).Info("Validating disbursement data after processing the disbursement file") + err = validateExpectationsAfterProcessDisbursement(ctx, disbursement.ID, it.models, it.dbConnectionPool) + if err != nil { + return fmt.Errorf("error validating data after process disbursement: %w", err) + } + log.Ctx(ctx).Info("Disbursement data validated") + + log.Ctx(ctx).Info("Starting disbursement using server API") + err = it.serverAPI.StartDisbursement(ctx, authToken, disbursement.ID, &httphandler.PatchDisbursementStatusRequest{Status: "STARTED"}) + if err != nil { + return fmt.Errorf("error starting disbursement: %w", err) + } + log.Ctx(ctx).Info("Disbursement started") + + log.Ctx(ctx).Info("Validating disbursement data after starting disbursement using server API") + err = validateExpectationsAfterStartDisbursement(ctx, disbursement.ID, it.models, it.dbConnectionPool) + if err != nil { + return fmt.Errorf("error validating data after process disbursement: %w", err) + } + log.Ctx(ctx).Info("Disbursement data validated") + + log.Ctx(ctx).Info("Starting anchor platform integration ......") + log.Ctx(ctx).Info("Starting challenge transaction on anchor platform") + challengeTx, err := it.anchorPlatform.StartChallengeTransaction() + if err != nil { + return fmt.Errorf("error creating SEP10 challenge transaction: %w", err) + } + log.Ctx(ctx).Info("Challenge transaction created") + + log.Ctx(ctx).Info("Signing challenge transaction with Sep10SigningKey") + signedTx, err := it.anchorPlatform.SignChallengeTransaction(challengeTx) + if err != nil { + return fmt.Errorf("error signing SEP10 challenge transaction: %w", err) + } + log.Ctx(ctx).Info("Challenge transaction signed") + + log.Ctx(ctx).Info("Sending challenge transaction to anchor platform") + authSEP10Token, err := it.anchorPlatform.SendSignedChallengeTransaction(signedTx) + if err != nil { + return fmt.Errorf("error sending SEP10 challenge transaction: %w", err) + } + log.Ctx(ctx).Info("Received authSEP10Token") + + log.Ctx(ctx).Info("Creating SEP24 deposit transaction on anchor platform") + authSEP24Token, _, err := it.anchorPlatform.CreateSep24DepositTransaction(authSEP10Token) + if err != nil { + return fmt.Errorf("error creating SEP24 deposit transaction: %w", err) + } + log.Ctx(ctx).Info("Received authSEP24Token") + + disbursementData, err := readDisbursementCSV(opts.DisbursementCSVFilePath, opts.DisbursementCSVFileName) + if err != nil { + return fmt.Errorf("error reading disbursement CSV: %w", err) + } + + log.Ctx(ctx).Info("Completing receiver registration using server API") + err = it.serverAPI.ReceiverRegistration(ctx, authSEP24Token, &data.ReceiverRegistrationRequest{ + OTP: data.TestnetAlwaysValidOTP, + PhoneNumber: disbursementData[0].Phone, + VerificationValue: disbursementData[0].VerificationValue, + VerificationType: disbursement.VerificationField, + ReCAPTCHAToken: opts.RecaptchaSiteKey, + }) + if err != nil { + return fmt.Errorf("error registring receiver: %w", err) + } + log.Ctx(ctx).Info("Receiver OTP obtained") + + log.Ctx(ctx).Info("Validating receiver data after completing registration") + err = validateExpectationsAfterReceiverRegistration(ctx, it.models, opts.ReceiverAccountPublicKey, opts.ReceiverAccountStellarMemo) + if err != nil { + return fmt.Errorf("error validating receiver after registration: %w", err) + } + log.Ctx(ctx).Info("Receiver data validated") + + log.Ctx(ctx).Info("Waiting for payment to be processed by TSS") + time.Sleep(paymentProcessTimeMinutes * time.Minute) + + log.Ctx(ctx).Info("Querying database to get disbursement receiver with payment data") + receivers, err := it.models.DisbursementReceivers.GetAll(ctx, it.dbConnectionPool, &data.QueryParams{}, disbursement.ID) + if err != nil { + return fmt.Errorf("error getting receivers: %w", err) + } + + payment := receivers[0].Payment + q := `SELECT * FROM submitter_transactions WHERE external_id = $1` + var tx tss.Transaction + err = it.dbConnectionPool.GetContext(ctx, &tx, q, payment.ID) + if err != nil { + return fmt.Errorf("getting TSS transaction from database: %w", err) + } + log.Ctx(ctx).Infof("TSS transaction: %+v", tx) + + log.Ctx(ctx).Info("Getting payment from disbursement receiver") + if payment.Status != data.SuccessPaymentStatus || payment.StellarTransactionID == "" { + return fmt.Errorf("payment was not processed successfully by TSS: %+v", payment) + } + + log.Ctx(ctx).Info("Payment was successfully updated by the TSS") + log.Ctx(ctx).Info("Validating transaction on Horizon Network") + ph, err := getTransactionOnHorizon(it.horizonClient, payment.StellarTransactionID) + if err != nil { + return fmt.Errorf("error getting transaction on horizon network: %w", err) + } + err = validateStellarTransaction(ph, opts.ReceiverAccountPublicKey, opts.DisbursedAssetCode, opts.DisbursetAssetIssuer, receivers[0].Payment.Amount) + if err != nil { + return fmt.Errorf("error validating stellar transaction: %w", err) + } + log.Ctx(ctx).Info("Transaction validated") + + log.Ctx(ctx).Info("πŸŽ‰πŸŽ‰πŸŽ‰Finishing integration tests, the receiver was successfully funded πŸŽ‰πŸŽ‰πŸŽ‰") + + return nil +} + +func (it *IntegrationTestsService) CreateTestData(ctx context.Context, opts IntegrationTestsOpts) error { + _, err := it.models.Assets.GetOrCreate(ctx, opts.DisbursedAssetCode, opts.DisbursetAssetIssuer) + if err != nil { + return fmt.Errorf("error getting or creating test asset: %w", err) + } + + _, err = it.models.Wallets.GetOrCreate(ctx, opts.WalletName, opts.WalletHomepage, opts.WalletDeepLink, opts.WalletSEP10Domain) + if err != nil { + return fmt.Errorf("error getting or creating test wallet: %w", err) + } + + return nil +} diff --git a/internal/integrationtests/main.go b/internal/integrationtests/main.go new file mode 100644 index 000000000..5884bbcef --- /dev/null +++ b/internal/integrationtests/main.go @@ -0,0 +1,6 @@ +package integrationtests + +import "embed" + +//go:embed files/* +var DisbursementCSVFiles embed.FS diff --git a/internal/integrationtests/server_api.go b/internal/integrationtests/server_api.go new file mode 100644 index 000000000..53ba8cb5d --- /dev/null +++ b/internal/integrationtests/server_api.go @@ -0,0 +1,257 @@ +package integrationtests + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/fs" + "mime/multipart" + "net/http" + "net/url" + "path" + "strings" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httphandler" +) + +const ( + loginURL = "login" + disbursementURL = "disbursements" + registrationURL = "wallet-registration" +) + +type ServerApiIntegrationTestsInterface interface { + Login(ctx context.Context) (*ServerApiAuthToken, error) + CreateDisbursement(ctx context.Context, authToken *ServerApiAuthToken, body *httphandler.PostDisbursementRequest) (*data.Disbursement, error) + ProcessDisbursement(ctx context.Context, authToken *ServerApiAuthToken, disbursementID string) error + StartDisbursement(ctx context.Context, authToken *ServerApiAuthToken, disbursementID string, body *httphandler.PatchDisbursementStatusRequest) error + ReceiverRegistration(ctx context.Context, authSEP24Token *AnchorPlatformAuthSEP24Token, body *data.ReceiverRegistrationRequest) error +} + +type ServerApiIntegrationTests struct { + HttpClient httpclient.HttpClientInterface + ServerApiBaseURL string + UserEmail string + UserPassword string + DisbursementCSVFilePath string + DisbursementCSVFileName string +} + +type ServerApiAuthToken struct { + Token string `json:"token"` +} + +// Login login the integration test user on SDP server API. +func (sa *ServerApiIntegrationTests) Login(ctx context.Context) (*ServerApiAuthToken, error) { + reqURL, err := url.JoinPath(sa.ServerApiBaseURL, loginURL) + if err != nil { + return nil, fmt.Errorf("error creating url: %w", err) + } + + reqBody, err := json.Marshal(&httphandler.LoginRequest{ + Email: sa.UserEmail, + Password: sa.UserPassword, + }) + if err != nil { + return nil, fmt.Errorf("error creating json post body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(string(reqBody))) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := sa.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request to server API post LOGIN: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return nil, fmt.Errorf("error trying to login on the server API") + } + + at := &ServerApiAuthToken{} + err = json.NewDecoder(resp.Body).Decode(at) + if err != nil { + return nil, fmt.Errorf("error decoding response body: %w", err) + } + + return at, nil +} + +// CreateDisbursement creates a new disbursement using the SDP server API. +func (sa *ServerApiIntegrationTests) CreateDisbursement(ctx context.Context, authToken *ServerApiAuthToken, body *httphandler.PostDisbursementRequest) (*data.Disbursement, error) { + reqURL, err := url.JoinPath(sa.ServerApiBaseURL, disbursementURL) + if err != nil { + return nil, fmt.Errorf("error creating url: %w", err) + } + + reqBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("error creating json post body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(string(reqBody))) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authToken.Token) + + resp, err := sa.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request to server API post DISBURSEMENT: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return nil, fmt.Errorf("error trying to create a new disbursement on the server API") + } + + disbursement := &data.Disbursement{} + err = json.NewDecoder(resp.Body).Decode(disbursement) + if err != nil { + return nil, fmt.Errorf("error decoding response body: %w", err) + } + + return disbursement, nil +} + +// createInstructionsRequest creates the request with multipart formdata to process disbursement on SDP server API. +func createInstructionsRequest(ctx context.Context, reqURL, disbursementCSVFilePath, disbursementCSVFileName string) (*http.Request, error) { + filePath := path.Join(disbursementCSVFilePath, disbursementCSVFileName) + + csvBytes, err := fs.ReadFile(DisbursementCSVFiles, filePath) + if err != nil { + return nil, fmt.Errorf("error reading csv file: %w", err) + } + + b := &bytes.Buffer{} + w := multipart.NewWriter(b) + defer w.Close() + + fileWriter, err := w.CreateFormFile("file", disbursementCSVFileName) + if err != nil { + return nil, fmt.Errorf("error creating form file with disbursement csv file: %w", err) + } + + _, err = io.Copy(fileWriter, bytes.NewReader(csvBytes)) + if err != nil { + return nil, fmt.Errorf("error copying file: %w", err) + } + // we need to close *multipart.Writter before pass as parameter in http.NewRequestWithContext + w.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, b) + if err != nil { + return nil, fmt.Errorf("error creating new request: %w", err) + } + + req.Header.Set("Content-Type", w.FormDataContentType()) + + return req, nil +} + +// ProcessDisbursement process the disbursement using the SDP server API. +func (sa *ServerApiIntegrationTests) ProcessDisbursement(ctx context.Context, authToken *ServerApiAuthToken, disbursementID string) error { + reqURL, err := url.JoinPath(sa.ServerApiBaseURL, disbursementURL, disbursementID, "instructions") + if err != nil { + return fmt.Errorf("error creating url: %w", err) + } + + req, err := createInstructionsRequest(ctx, reqURL, sa.DisbursementCSVFilePath, sa.DisbursementCSVFileName) + if err != nil { + return fmt.Errorf("error creating instructions request with multipart form-data: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+authToken.Token) + + resp, err := sa.HttpClient.Do(req) + if err != nil { + return fmt.Errorf("error making request to server API post DISBURSEMENT INSTRUCTIONS: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return fmt.Errorf("error trying to process the disbursement CSV file on the server API") + } + + return nil +} + +// StartDisbursement starts the disbursement using the SDP server API. +func (sa *ServerApiIntegrationTests) StartDisbursement(ctx context.Context, authToken *ServerApiAuthToken, disbursementID string, body *httphandler.PatchDisbursementStatusRequest) error { + reqURL, err := url.JoinPath(sa.ServerApiBaseURL, disbursementURL, disbursementID, "status") + if err != nil { + return fmt.Errorf("error creating url: %w", err) + } + + reqBody, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("error creating json post body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, reqURL, strings.NewReader(string(reqBody))) + if err != nil { + return fmt.Errorf("error creating new request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authToken.Token) + + resp, err := sa.HttpClient.Do(req) + if err != nil { + return fmt.Errorf("error making request to server API patch DISBURSEMENT: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return fmt.Errorf("error trying to start the disbursement on the server API") + } + + return nil +} + +// ReceiverRegistration completes the receiver registration using SDP server API and the anchor platform. +func (sa *ServerApiIntegrationTests) ReceiverRegistration(ctx context.Context, authSEP24Token *AnchorPlatformAuthSEP24Token, body *data.ReceiverRegistrationRequest) error { + reqURL, err := url.JoinPath(sa.ServerApiBaseURL, registrationURL, "verification") + if err != nil { + return fmt.Errorf("error creating url: %w", err) + } + + reqBody, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("error creating json post body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(string(reqBody))) + if err != nil { + return fmt.Errorf("error creating new request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authSEP24Token.Token) + + resp, err := sa.HttpClient.Do(req) + if err != nil { + return fmt.Errorf("error making request to server API post WALLET REGISTRATION VERIFICATION: %w", err) + } + + if resp.StatusCode/100 != 2 { + logErrorResponses(ctx, resp.Body) + return fmt.Errorf("error trying to complete receiver registration on the server API") + } + + return nil +} + +// Ensuring that ServerApiIntegrationTests is implementing ServerApiIntegrationTestsInterface. +var _ ServerApiIntegrationTestsInterface = (*ServerApiIntegrationTests)(nil) diff --git a/internal/integrationtests/server_api_test.go b/internal/integrationtests/server_api_test.go new file mode 100644 index 000000000..c46079fd7 --- /dev/null +++ b/internal/integrationtests/server_api_test.go @@ -0,0 +1,353 @@ +package integrationtests + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httphandler" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_Login(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + + sa := ServerApiIntegrationTests{ + HttpClient: &httpClientMock, + ServerApiBaseURL: "http://mock_server.com/", + UserEmail: "user_mock@email.com", + UserPassword: "userPass123", + } + + ctx := context.Background() + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + at, err := sa.Login(ctx) + require.EqualError(t, err, "error making request to server API post LOGIN: error calling the request") + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to login on server api", func(t *testing.T) { + loginResponse := `{Invalid credentials.}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(loginResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, err := sa.Login(ctx) + require.EqualError(t, err, "error trying to login on the server API") + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error invalid response body", func(t *testing.T) { + loginResponse := `` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(loginResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, err := sa.Login(ctx) + require.EqualError(t, err, "error decoding response body: EOF") + assert.Empty(t, at) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully logging on server api", func(t *testing.T) { + loginResponse := `{ + "token": "valid_token" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(loginResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + at, err := sa.Login(ctx) + require.NoError(t, err) + + assert.Equal(t, "valid_token", at.Token) + + httpClientMock.AssertExpectations(t) + }) +} + +func Test_CreateDisbursement(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + + sa := ServerApiIntegrationTests{ + HttpClient: &httpClientMock, + ServerApiBaseURL: "http://mock_server.com/", + } + + ctx := context.Background() + + authToken := &ServerApiAuthToken{ + Token: "valid_token", + } + + reqBody := &httphandler.PostDisbursementRequest{ + Name: "mockDisbursement", + CountryCode: "USA", + WalletID: "123", + AssetID: "890", + } + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + d, err := sa.CreateDisbursement(ctx, authToken, reqBody) + require.EqualError(t, err, "error making request to server API post DISBURSEMENT: error calling the request") + assert.Empty(t, d) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to create a disbursement on server api", func(t *testing.T) { + disbursementResponse := `{Invalid credentials.}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + d, err := sa.CreateDisbursement(ctx, authToken, reqBody) + require.EqualError(t, err, "error trying to create a new disbursement on the server API") + assert.Empty(t, d) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error invalid response body", func(t *testing.T) { + disbursementResponse := `` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + d, err := sa.CreateDisbursement(ctx, authToken, reqBody) + require.EqualError(t, err, "error decoding response body: EOF") + assert.Empty(t, d) + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully creating a disbursement on server api", func(t *testing.T) { + disbursementResponse := `{ + "id": "619da857-8725-4c58-933d-c120a458e0f5", + "name": "mockDisbursement", + "status": "DRAFT" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + d, err := sa.CreateDisbursement(ctx, authToken, reqBody) + require.NoError(t, err) + + assert.Equal(t, "mockDisbursement", d.Name) + assert.Equal(t, "619da857-8725-4c58-933d-c120a458e0f5", d.ID) + assert.Equal(t, "DRAFT", string(d.Status)) + + httpClientMock.AssertExpectations(t) + }) +} + +func Test_ProcessDisbursement(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + + sa := ServerApiIntegrationTests{ + HttpClient: &httpClientMock, + ServerApiBaseURL: "http://mock_server.com/", + DisbursementCSVFilePath: "files", + DisbursementCSVFileName: "disbursement_integration_tests.csv", + } + + ctx := context.Background() + + authToken := &ServerApiAuthToken{ + Token: "valid_token", + } + + mockDisbursementID := "disbursement_id" + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + err := sa.ProcessDisbursement(ctx, authToken, mockDisbursementID) + require.EqualError(t, err, "error making request to server API post DISBURSEMENT INSTRUCTIONS: error calling the request") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to process the disbursement on server api", func(t *testing.T) { + disbursementResponse := `{error processing disbursement.}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + err := sa.ProcessDisbursement(ctx, authToken, mockDisbursementID) + require.EqualError(t, err, "error trying to process the disbursement CSV file on the server API") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully creating a disbursement on server api", func(t *testing.T) { + disbursementResponse := `{ + "message": "File uploaded successfully" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + err := sa.ProcessDisbursement(ctx, authToken, mockDisbursementID) + require.NoError(t, err) + + httpClientMock.AssertExpectations(t) + }) +} + +func Test_StartDisbursement(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + + sa := ServerApiIntegrationTests{ + HttpClient: &httpClientMock, + ServerApiBaseURL: "http://mock_server.com/", + } + + ctx := context.Background() + + authToken := &ServerApiAuthToken{ + Token: "valid_token", + } + + mockDisbursementID := "disbursement_id" + reqBody := &httphandler.PatchDisbursementStatusRequest{ + Status: "STARTED", + } + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + err := sa.StartDisbursement(ctx, authToken, mockDisbursementID, reqBody) + require.EqualError(t, err, "error making request to server API patch DISBURSEMENT: error calling the request") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to start the disbursement on server api", func(t *testing.T) { + disbursementResponse := `{error starting disbursement.}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + err := sa.StartDisbursement(ctx, authToken, mockDisbursementID, reqBody) + require.EqualError(t, err, "error trying to start the disbursement on the server API") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully creating a disbursement on server api", func(t *testing.T) { + disbursementResponse := `{ + "message": "Disbursement started" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + err := sa.StartDisbursement(ctx, authToken, mockDisbursementID, reqBody) + require.NoError(t, err) + + httpClientMock.AssertExpectations(t) + }) +} + +func Test_ReceiverRegistration(t *testing.T) { + httpClientMock := httpclient.HttpClientMock{} + + sa := ServerApiIntegrationTests{ + HttpClient: &httpClientMock, + ServerApiBaseURL: "http://mock_server.com/", + } + + ctx := context.Background() + + authToken := &AnchorPlatformAuthSEP24Token{ + Token: "valid_token", + } + + reqBody := &data.ReceiverRegistrationRequest{ + PhoneNumber: "+18554212274", + OTP: "123456", + VerificationValue: "1999-01-30", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "captchtoken", + } + + t.Run("error calling httpClient.Do", func(t *testing.T) { + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(nil, fmt.Errorf("error calling the request")).Once() + err := sa.ReceiverRegistration(ctx, authToken, reqBody) + require.EqualError(t, err, "error making request to server API post WALLET REGISTRATION VERIFICATION: error calling the request") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("error trying to registrate receiver on server api", func(t *testing.T) { + disbursementResponse := `{error registrating receiver.}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusBadRequest, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + err := sa.ReceiverRegistration(ctx, authToken, reqBody) + + require.EqualError(t, err, "error trying to complete receiver registration on the server API") + + httpClientMock.AssertExpectations(t) + }) + + t.Run("succesfully registrating receiver on server api", func(t *testing.T) { + disbursementResponse := `{ + "message": "ok" + }` + + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(disbursementResponse)), + StatusCode: http.StatusOK, + } + httpClientMock.On("Do", mock.AnythingOfType("*http.Request")).Return(response, nil).Once() + + err := sa.ReceiverRegistration(ctx, authToken, reqBody) + + require.NoError(t, err) + + httpClientMock.AssertExpectations(t) + }) +} diff --git a/internal/integrationtests/utils.go b/internal/integrationtests/utils.go new file mode 100644 index 000000000..9148741e0 --- /dev/null +++ b/internal/integrationtests/utils.go @@ -0,0 +1,65 @@ +package integrationtests + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/fs" + "path" + + "github.com/gocarina/gocsv" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +// logErrorResponses logs the response body for requests with error status. +func logErrorResponses(ctx context.Context, body io.ReadCloser) { + respBody, err := io.ReadAll(body) + if err == nil { + log.Ctx(ctx).Infof("error message response: %s", string(respBody)) + } +} + +func readDisbursementCSV(disbursementFilePath string, disbursementFileName string) ([]*data.DisbursementInstruction, error) { + filePath := path.Join(disbursementFilePath, disbursementFileName) + + csvBytes, err := fs.ReadFile(DisbursementCSVFiles, filePath) + if err != nil { + return nil, fmt.Errorf("error reading csv file: %w", err) + } + + instructions := []*data.DisbursementInstruction{} + if err = gocsv.UnmarshalBytes(csvBytes, &instructions); err != nil { + return nil, fmt.Errorf("error parsing csv file: %w", err) + } + + return instructions, nil +} + +type PaymentHorizon struct { + ReceiverAccount string `json:"to"` + Amount string `json:"amount"` + AssetCode string `json:"asset_code"` + AssetIssuer string `json:"asset_issuer"` + TransactionSuccessful bool `json:"transaction_successful"` +} + +func getTransactionOnHorizon(client horizonclient.ClientInterface, transactionID string) (*PaymentHorizon, error) { + ph := &PaymentHorizon{} + records, err := client.Payments(horizonclient.OperationRequest{ForTransaction: transactionID}) + if err != nil { + return nil, fmt.Errorf("error checking payment in horizon: %w", err) + } + paymentRecord, err := json.Marshal(records.Embedded.Records[0]) + if err != nil { + return nil, fmt.Errorf("error marshaling payment record: %w", err) + } + err = json.Unmarshal(paymentRecord, ph) + if err != nil { + return nil, fmt.Errorf("error unmarshling payment record: %w", err) + } + + return ph, nil +} diff --git a/internal/integrationtests/utils_test.go b/internal/integrationtests/utils_test.go new file mode 100644 index 000000000..1318cdd8e --- /dev/null +++ b/internal/integrationtests/utils_test.go @@ -0,0 +1,144 @@ +package integrationtests + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "path" + "strings" + "testing" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/protocols/horizon/operations" + + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/problem" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_logErrorResponses(t *testing.T) { + body := `{error response body}` + response := &http.Response{ + Body: io.NopCloser(strings.NewReader(body)), + } + ctx := context.Background() + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.InfoLevel) + + logErrorResponses(ctx, response.Body) + + require.Contains(t, buf.String(), `level=info msg="error message response: {error response body}`) +} + +func Test_readDisbursementCSV(t *testing.T) { + t.Run("error trying read csv file", func(t *testing.T) { + filePath := path.Join("files", "invalid_file.csv") + expectedError := fmt.Sprintf("error reading csv file: open %s: file does not exist", filePath) + + data, err := readDisbursementCSV("files", "invalid_file.csv") + require.EqualError(t, err, expectedError) + assert.Empty(t, data) + }) + + t.Run("error opening empty csv file", func(t *testing.T) { + data, err := readDisbursementCSV("files", "empty_csv_file.csv") + require.EqualError(t, err, "error parsing csv file: empty csv file given") + assert.Empty(t, data) + }) + + t.Run("reading csv file", func(t *testing.T) { + data, err := readDisbursementCSV("files", "disbursement_integration_tests.csv") + require.NoError(t, err) + assert.Equal(t, data[0].Amount, "0.1") + assert.Equal(t, data[0].Phone, "+12025550191") + assert.Equal(t, data[0].ID, "1") + assert.Equal(t, data[0].VerificationValue, "1999-03-30") + }) +} + +func Test_getTransactionInHorizon(t *testing.T) { + mockHorizonClient := &horizonclient.MockClient{} + mockTransactionID := "transactionID" + + t.Run("error trying to get transaction on horizon", func(t *testing.T) { + mockHorizonClient. + On("Payments", horizonclient.OperationRequest{ForTransaction: mockTransactionID}). + Return(operations.OperationsPage{}, horizonclient.Error{ + Problem: problem.NotFound, + }). + Once() + + ph, err := getTransactionOnHorizon(mockHorizonClient, mockTransactionID) + require.EqualError(t, err, "error checking payment in horizon: horizon error: \"Resource Missing\" - check horizon.Error.Problem for more information") + assert.Empty(t, ph) + + mockHorizonClient.AssertExpectations(t) + }) + + horizonResponse := `{ + "_embedded": { + "records": [ + { + "_links": { + "self": { + "href": "" + }, + "transaction": { + "href": "" + }, + "effects": { + "href": "" + }, + "succeeds": { + "href": "" + }, + "precedes": { + "href": "" + } + }, + "id": "123456", + "paging_token": "67890", + "transaction_successful": true, + "source_account": "GBZF7AS3TBASAL5RQ7ECJODFWFLBDCKJK5SMPUCO5R36CJUIZRWQJTGB", + "type": "payment", + "type_i": 1, + "created_at": "2023-06-15T14:01:59Z", + "transaction_hash": "17qw02bb7aaa949e9a852b48176e64dae381f4ce20af454b5f4d405ce67wsad1", + "asset_type": "credit_alphanum4", + "asset_code": "USDC", + "asset_issuer": "GBZF7AS3TBASAL5RQ7ECJODFWFLBDCKJK5SMPUCO5R36CJUIZRWQJTGB", + "from": "GBZF7AS3TBASAL5RQ7ECJODFWFLBDCKJK5SMPUCO5R36CJUIZRWQJTGB", + "to": "GD44L3Q6NYRFPVOX4CJUUV63QEOOU3R5JNQJBLR6WWXFWYHEGK2YVBQ7", + "amount": "100.0000000" + } + ] + } + } + ` + var paymentPage operations.OperationsPage + + err := json.Unmarshal([]byte(horizonResponse), &paymentPage) + require.NoError(t, err) + + t.Run("successful get transaction on horizon", func(t *testing.T) { + mockHorizonClient. + On("Payments", horizonclient.OperationRequest{ForTransaction: mockTransactionID}). + Return(paymentPage, nil). + Once() + + ph, err := getTransactionOnHorizon(mockHorizonClient, mockTransactionID) + require.NoError(t, err) + assert.Equal(t, "GD44L3Q6NYRFPVOX4CJUUV63QEOOU3R5JNQJBLR6WWXFWYHEGK2YVBQ7", ph.ReceiverAccount) + assert.Equal(t, "USDC", ph.AssetCode) + assert.Equal(t, "GBZF7AS3TBASAL5RQ7ECJODFWFLBDCKJK5SMPUCO5R36CJUIZRWQJTGB", ph.AssetIssuer) + assert.Equal(t, "100.0000000", ph.Amount) + assert.Equal(t, true, ph.TransactionSuccessful) + + mockHorizonClient.AssertExpectations(t) + }) +} diff --git a/internal/integrationtests/validations.go b/internal/integrationtests/validations.go new file mode 100644 index 000000000..5adc1582f --- /dev/null +++ b/internal/integrationtests/validations.go @@ -0,0 +1,106 @@ +package integrationtests + +import ( + "context" + "fmt" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +func validateExpectationsAfterProcessDisbursement(ctx context.Context, disbursementID string, models *data.Models, sqlExec db.SQLExecuter) error { + disbursement, err := models.Disbursements.Get(ctx, sqlExec, disbursementID) + if err != nil { + return fmt.Errorf("error getting disbursement: %w", err) + } + + if disbursement.Status != data.ReadyDisbursementStatus { + return fmt.Errorf("invalid status for disbursement after process disbursement") + } + + receivers, err := models.DisbursementReceivers.GetAll(ctx, sqlExec, &data.QueryParams{}, disbursementID) + if err != nil { + return fmt.Errorf("error getting receivers from disbursement: %w", err) + } + + if len(receivers) <= 0 { + return fmt.Errorf("error getting receivers from disbursement: receivers not found") + } + receiver := receivers[0] + + // TODO upgrade this function to validate multiples receiver wallets and payments. + if receiver.ReceiverWallet.Status != data.DraftReceiversWalletStatus { + return fmt.Errorf("invalid status for receiver_wallet after process disbursement") + } + + if receiver.Payment.Status != data.DraftPaymentStatus { + return fmt.Errorf("invalid status for payment after process disbursement") + } + + return nil +} + +func validateExpectationsAfterStartDisbursement(ctx context.Context, disbursementID string, models *data.Models, sqlExec db.SQLExecuter) error { + disbursement, err := models.Disbursements.Get(ctx, sqlExec, disbursementID) + if err != nil { + return fmt.Errorf("error getting disbursement: %w", err) + } + + if disbursement.Status != data.StartedDisbursementStatus { + return fmt.Errorf("invalid status for disbursement after start disbursement") + } + + receivers, err := models.DisbursementReceivers.GetAll(ctx, sqlExec, &data.QueryParams{}, disbursementID) + if err != nil { + return fmt.Errorf("error getting receivers from disbursement: %w", err) + } + if len(receivers) <= 0 { + return fmt.Errorf("error getting receivers from disbursement: receivers not found") + } + + receiver := receivers[0] + + // TODO upgrade this function to validate multiples receiver wallets and payments. + if receiver.ReceiverWallet.Status != data.ReadyReceiversWalletStatus { + return fmt.Errorf("invalid status for receiver_wallet after start disbursement") + } + + if receiver.Payment.Status != data.ReadyPaymentStatus { + return fmt.Errorf("invalid status for payment after start disbursement") + } + + return nil +} + +func validateExpectationsAfterReceiverRegistration(ctx context.Context, models *data.Models, stellarAccount, stellarMemo string) error { + receiverWallet, err := models.ReceiverWallet.GetByStellarAccountAndMemo(ctx, stellarAccount, stellarMemo) + if err != nil { + return fmt.Errorf("error getting receiver wallet with stellar account: %w", err) + } + + if receiverWallet.Status != data.RegisteredReceiversWalletStatus { + return fmt.Errorf("invalid status for receiver_wallet after receiver registration") + } + + return nil +} + +func validateStellarTransaction(paymentHorizon *PaymentHorizon, receiverAccount, disbursedAssetCode, disbursedAssetIssuer, amount string) error { + if !paymentHorizon.TransactionSuccessful { + return fmt.Errorf("transaction was not successful on horizon network") + } + + if paymentHorizon.ReceiverAccount != receiverAccount { + return fmt.Errorf("transaction sent to wrong receiver account") + } + + if paymentHorizon.Amount != amount { + return fmt.Errorf("transaction with wrong amount") + } + + if paymentHorizon.AssetCode != disbursedAssetCode || paymentHorizon.AssetIssuer != disbursedAssetIssuer { + return fmt.Errorf("transaction with wrong disbursed asset") + } + + return nil +} diff --git a/internal/integrationtests/validations_test.go b/internal/integrationtests/validations_test.go new file mode 100644 index 000000000..e272c7ece --- /dev/null +++ b/internal/integrationtests/validations_test.go @@ -0,0 +1,325 @@ +package integrationtests + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/require" +) + +func Test_validationAfterProcessDisbursement(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + t.Run("disbursement not found", func(t *testing.T) { + err = validateExpectationsAfterProcessDisbursement(ctx, "invalid_id", models, dbConnectionPool) + require.EqualError(t, err, "error getting disbursement: record not found") + }) + + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + t.Run("invalid disbursement status", func(t *testing.T) { + invalidDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "Invalid Disbursement", + Status: data.CompletedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + err = validateExpectationsAfterProcessDisbursement(ctx, invalidDisbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "invalid status for disbursement after process disbursement") + }) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + t.Run("disbursement receivers not found", func(t *testing.T) { + err = validateExpectationsAfterProcessDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "error getting receivers from disbursement: receivers not found") + }) + + t.Run("invalid receiver wallet status", func(t *testing.T) { + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.FlaggedReceiversWalletStatus) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.DraftPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = validateExpectationsAfterProcessDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "invalid status for receiver_wallet after process disbursement") + }) + + t.Run("invalid payment status", func(t *testing.T) { + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.FailedPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = validateExpectationsAfterProcessDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "invalid status for payment after process disbursement") + }) + + t.Run("successfull validation", func(t *testing.T) { + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.DraftPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = validateExpectationsAfterProcessDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.NoError(t, err) + }) +} + +func Test_validationAfterStartDisbursement(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + t.Run("disbursement not found", func(t *testing.T) { + err = validateExpectationsAfterStartDisbursement(ctx, "invalid_id", models, dbConnectionPool) + require.EqualError(t, err, "error getting disbursement: record not found") + }) + + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + t.Run("invalid disbursement status", func(t *testing.T) { + invalidDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "Invalid Disbursement", + Status: data.CompletedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + err = validateExpectationsAfterStartDisbursement(ctx, invalidDisbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "invalid status for disbursement after start disbursement") + }) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + t.Run("disbursement receivers not found", func(t *testing.T) { + err = validateExpectationsAfterStartDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "error getting receivers from disbursement: receivers not found") + }) + + t.Run("invalid receiver wallet status", func(t *testing.T) { + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.FlaggedReceiversWalletStatus) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.DraftPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = validateExpectationsAfterStartDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "invalid status for receiver_wallet after start disbursement") + }) + + t.Run("invalid payment status", func(t *testing.T) { + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.FailedPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = validateExpectationsAfterStartDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.EqualError(t, err, "invalid status for payment after start disbursement") + }) + + t.Run("successfull validation", func(t *testing.T) { + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.ReadyPaymentStatus, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + err = validateExpectationsAfterStartDisbursement(ctx, disbursement.ID, models, dbConnectionPool) + require.NoError(t, err) + }) +} + +func Test_validationAfterReceiverRegistration(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + t.Run("receiver wallet not found", func(t *testing.T) { + err = validateExpectationsAfterReceiverRegistration(ctx, models, "invalid_stellar_account", "invalid_stellar_memo") + require.EqualError(t, err, "error getting receiver wallet with stellar account: no receiver wallet could be found in GetByStellarAccountAndMemo: record not found") + }) + + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + t.Run("invalid receiver wallet status validation", func(t *testing.T) { + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + err = validateExpectationsAfterReceiverRegistration(ctx, models, receiverWallet.StellarAddress, receiverWallet.StellarMemo) + require.EqualError(t, err, "invalid status for receiver_wallet after receiver registration") + }) + + t.Run("successfull validation", func(t *testing.T) { + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + err = validateExpectationsAfterReceiverRegistration(ctx, models, receiverWallet.StellarAddress, receiverWallet.StellarMemo) + require.NoError(t, err) + }) +} + +func Test_validateStellarTransaction(t *testing.T) { + mockReceiverAccount := "GD44L3Q6NYRFPVOX4CJUUV63QEOOU3R5JNQJBLR6WWXFWYHEGK2YVBQ7" + mockassetCode := "USDC" + mockassetIssuer := "GBZF7AS3TBASAL5RQ7ECJODFWFLBDCKJK5SMPUCO5R36CJUIZRWQJTGB" + mockAmount := "0.1" + + t.Run("error transaction not successful", func(t *testing.T) { + err := validateStellarTransaction(&PaymentHorizon{ + TransactionSuccessful: false, + }, mockReceiverAccount, mockassetCode, mockassetIssuer, mockAmount) + require.EqualError(t, err, "transaction was not successful on horizon network") + }) + + t.Run("error wrong receiver account", func(t *testing.T) { + err := validateStellarTransaction(&PaymentHorizon{ + TransactionSuccessful: true, + ReceiverAccount: "invalidReceiver", + }, mockReceiverAccount, mockassetCode, mockassetIssuer, mockAmount) + require.EqualError(t, err, "transaction sent to wrong receiver account") + }) + + t.Run("error wrong amount", func(t *testing.T) { + err := validateStellarTransaction(&PaymentHorizon{ + TransactionSuccessful: true, + ReceiverAccount: "GD44L3Q6NYRFPVOX4CJUUV63QEOOU3R5JNQJBLR6WWXFWYHEGK2YVBQ7", + Amount: "20", + }, mockReceiverAccount, mockassetCode, mockassetIssuer, mockAmount) + require.EqualError(t, err, "transaction with wrong amount") + }) + + t.Run("error wrong asset code", func(t *testing.T) { + err := validateStellarTransaction(&PaymentHorizon{ + TransactionSuccessful: true, + ReceiverAccount: "GD44L3Q6NYRFPVOX4CJUUV63QEOOU3R5JNQJBLR6WWXFWYHEGK2YVBQ7", + Amount: "0.1", + AssetCode: "invalidCode", + AssetIssuer: "GBZF7AS3TBASAL5RQ7ECJODFWFLBDCKJK5SMPUCO5R36CJUIZRWQJTGB", + }, mockReceiverAccount, mockassetCode, mockassetIssuer, mockAmount) + require.EqualError(t, err, "transaction with wrong disbursed asset") + }) + + t.Run("error wrong asset issuer", func(t *testing.T) { + err := validateStellarTransaction(&PaymentHorizon{ + TransactionSuccessful: true, + ReceiverAccount: "GD44L3Q6NYRFPVOX4CJUUV63QEOOU3R5JNQJBLR6WWXFWYHEGK2YVBQ7", + Amount: "0.1", + AssetCode: "USDC", + AssetIssuer: "invalidIssuer", + }, mockReceiverAccount, mockassetCode, mockassetIssuer, mockAmount) + require.EqualError(t, err, "transaction with wrong disbursed asset") + }) + + t.Run("successful validation", func(t *testing.T) { + err := validateStellarTransaction(&PaymentHorizon{ + TransactionSuccessful: true, + ReceiverAccount: "GD44L3Q6NYRFPVOX4CJUUV63QEOOU3R5JNQJBLR6WWXFWYHEGK2YVBQ7", + Amount: "0.1", + AssetCode: "USDC", + AssetIssuer: "GBZF7AS3TBASAL5RQ7ECJODFWFLBDCKJK5SMPUCO5R36CJUIZRWQJTGB", + }, mockReceiverAccount, mockassetCode, mockassetIssuer, mockAmount) + require.NoError(t, err) + }) +} diff --git a/internal/message/aws_ses_client.go b/internal/message/aws_ses_client.go new file mode 100644 index 000000000..6d75bbfb7 --- /dev/null +++ b/internal/message/aws_ses_client.go @@ -0,0 +1,117 @@ +package message + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ses" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +// awsSESInterface is used to send emails. +type awsSESInterface interface { + SendEmail(input *ses.SendEmailInput) (*ses.SendEmailOutput, error) +} + +// awsSESClient is used to send emails. +type awsSESClient struct { + emailService awsSESInterface + senderID string +} + +func (t *awsSESClient) MessengerType() MessengerType { + return MessengerTypeAWSEmail +} + +func (a *awsSESClient) SendMessage(message Message) error { + err := message.ValidateFor(a.MessengerType()) + if err != nil { + return fmt.Errorf("validating message to send an email through AWS: %w", err) + } + + emailTemplate, err := generateAWSEmail(message, a.senderID) + if err != nil { + return fmt.Errorf("generating AWS SES email template: %w", err) + } + + _, err = a.emailService.SendEmail(emailTemplate) + if err != nil { + return fmt.Errorf("sending AWS SES email: %w", err) + } + + log.Debugf("πŸŽ‰ AWS SES sent an email to the receiver %q", utils.TruncateString(message.ToEmail, 3)) + return nil +} + +// generateAWSEmail generates the email object to send an email through AWS SES. +func generateAWSEmail(message Message, sender string) (*ses.SendEmailInput, error) { + html, err := htmltemplate.ExecuteHTMLTemplateForEmailEmptyBody(htmltemplate.EmptyBodyEmailTemplate{Body: message.Message}) + if err != nil { + return nil, fmt.Errorf("generating html template: %w", err) + } + + return &ses.SendEmailInput{ + Destination: &ses.Destination{ + CcAddresses: []*string{}, + ToAddresses: []*string{ + aws.String(message.ToEmail), + }, + }, + Message: &ses.Message{ + Body: &ses.Body{ + Html: &ses.Content{ + Charset: aws.String("utf-8"), + Data: aws.String(html), + }, + }, + Subject: &ses.Content{ + Charset: aws.String("utf-8"), + Data: aws.String(message.Title), + }, + }, + Source: aws.String(sender), + }, nil +} + +// NewAWSSESClient creates a new AWS SES client, that is used to send emails. +func NewAWSSESClient(accessKeyID, secretAccessKey, region, senderID string) (*awsSESClient, error) { + accessKeyID = strings.TrimSpace(accessKeyID) + if accessKeyID == "" { + return nil, fmt.Errorf("aws accessKeyID is empty") + } + + secretAccessKey = strings.TrimSpace(secretAccessKey) + if secretAccessKey == "" { + return nil, fmt.Errorf("aws secretAccessKey is empty") + } + + region = strings.TrimSpace(region) + if region == "" { + return nil, fmt.Errorf("aws region is empty") + } + + senderID = strings.TrimSpace(senderID) + if err := utils.ValidateEmail(senderID); err != nil { + return nil, fmt.Errorf("aws SES (email) senderID is invalid: %w", err) + } + + awsSession, err := session.NewSession(&aws.Config{ + Credentials: credentials.NewStaticCredentials(accessKeyID, secretAccessKey, ""), + Region: aws.String(region), + }) + if err != nil { + return nil, fmt.Errorf("creating AWS session: %w", err) + } + + return &awsSESClient{ + senderID: senderID, + emailService: ses.New(awsSession), + }, nil +} + +var _ MessengerClient = (*awsSESClient)(nil) diff --git a/internal/message/aws_ses_client_test.go b/internal/message/aws_ses_client_test.go new file mode 100644 index 000000000..8b708c54a --- /dev/null +++ b/internal/message/aws_ses_client_test.go @@ -0,0 +1,146 @@ +package message + +import ( + "fmt" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ses" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockAWSSESClient struct { + mock.Mock +} + +func (m *mockAWSSESClient) SendEmail(input *ses.SendEmailInput) (*ses.SendEmailOutput, error) { + args := m.Called(input) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*ses.SendEmailOutput), args.Error(1) +} + +func Test_NewAWSSESClient(t *testing.T) { + // Declare types in advance to make sure these are the types being returned + var gotAWSSESClient *awsSESClient + var err error + + // accessKeyID cannot be empty + gotAWSSESClient, err = NewAWSSESClient("", "", "", "") + require.Nil(t, gotAWSSESClient) + require.EqualError(t, err, "aws accessKeyID is empty") + + // secretAccessKey cannot be empty + gotAWSSESClient, err = NewAWSSESClient("accessKeyID", "", "", "") + require.Nil(t, gotAWSSESClient) + require.EqualError(t, err, "aws secretAccessKey is empty") + + // region cannot be empty + gotAWSSESClient, err = NewAWSSESClient("accessKeyID", "secretAccessKey", "", "") + require.Nil(t, gotAWSSESClient) + require.EqualError(t, err, "aws region is empty") + + // [email] type needs a valid email as a sender ID: + gotAWSSESClient, err = NewAWSSESClient("accessKeyID", "secretAccessKey", "region", "invalid-email") + require.Nil(t, gotAWSSESClient) + require.EqualError(t, err, "aws SES (email) senderID is invalid: the provided email is not valid") + + // [email] all fields are present πŸŽ‰ + gotAWSSESClient, err = NewAWSSESClient("accessKeyID", "secretAccessKey", "region", "foo@test.com") + require.NoError(t, err) + require.NotNil(t, gotAWSSESClient) +} + +func Test_AWSSES_SendMessage_messageIsInvalid(t *testing.T) { + var mAWS MessengerClient = &awsSESClient{} + err := mAWS.SendMessage(Message{}) + require.EqualError(t, err, "validating message to send an email through AWS: invalid message: email cannot be empty") +} + +func Test_AWSSES_SendMessage_errorIsHandledCorrectly(t *testing.T) { + testSenderID := "sender@test.com" + message := Message{ToEmail: "foo@test.com", Title: "test title", Message: "foo bar"} + emailStr, err := generateAWSEmail(message, testSenderID) + require.NoError(t, err) + + mAWSSES := mockAWSSESClient{} + mAWSSES. + On("SendEmail", emailStr). + Return(nil, fmt.Errorf("test AWS SES error")). + Once() + + mAWS := awsSESClient{emailService: &mAWSSES, senderID: "sender@test.com"} + err = mAWS.SendMessage(Message{ToEmail: "foo@test.com", Title: "test title", Message: "foo bar"}) + require.EqualError(t, err, "sending AWS SES email: test AWS SES error") + + mAWSSES.AssertExpectations(t) +} + +func Test_AWSSES_SendMessage_success(t *testing.T) { + testSenderID := "sender@test.com" + message := Message{ToEmail: "foo@test.com", Title: "test title", Message: "foo bar"} + emailStr, err := generateAWSEmail(message, testSenderID) + require.NoError(t, err) + + mAWSSES := mockAWSSESClient{} + mAWSSES. + On("SendEmail", emailStr). + Return(nil, nil). + Once() + + mAWS := awsSESClient{emailService: &mAWSSES, senderID: "sender@test.com"} + err = mAWS.SendMessage(Message{ToEmail: "foo@test.com", Title: "test title", Message: "foo bar"}) + require.NoError(t, err) + + mAWSSES.AssertExpectations(t) +} + +func Test_generateAWSEmail_success(t *testing.T) { + message := Message{ + ToEmail: "receiver@test.com", + Message: "Helo world!", + Title: "title", + } + gotEmail, err := generateAWSEmail(message, "sender@test.com") + require.NoError(t, err) + + wantHTML := ` + + + + + + + + Helo world! + + ` + wantHTML = strings.TrimSpace(wantHTML) + // remove tabs: + wantHTML = strings.ReplaceAll(wantHTML, "\t\t", " ") + wantHTML = strings.ReplaceAll(wantHTML, "\t", "") + + wantEmail := &ses.SendEmailInput{ + Destination: &ses.Destination{ + CcAddresses: []*string{}, + ToAddresses: []*string{aws.String(message.ToEmail)}, + }, + Message: &ses.Message{ + Body: &ses.Body{ + Html: &ses.Content{ + Charset: aws.String("utf-8"), + Data: aws.String(wantHTML), + }, + }, + Subject: &ses.Content{ + Charset: aws.String("utf-8"), + Data: aws.String("title"), + }, + }, + Source: aws.String("sender@test.com"), + } + require.Equal(t, wantEmail, gotEmail) +} diff --git a/internal/message/aws_sns_client.go b/internal/message/aws_sns_client.go new file mode 100644 index 000000000..40d90a00a --- /dev/null +++ b/internal/message/aws_sns_client.go @@ -0,0 +1,92 @@ +package message + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +// awsSNSInterface is used to send sms. +type awsSNSInterface interface { + Publish(input *sns.PublishInput) (*sns.PublishOutput, error) +} + +// awsSNSClient is used to send sms. +type awsSNSClient struct { + snsService awsSNSInterface + senderID string +} + +func (t *awsSNSClient) MessengerType() MessengerType { + return MessengerTypeAWSSMS +} + +func (a *awsSNSClient) SendMessage(message Message) error { + err := message.ValidateFor(a.MessengerType()) + if err != nil { + return fmt.Errorf("validating message to send an SMS through AWS: %w", err) + } + + messageAttributes := map[string]*sns.MessageAttributeValue{ + "AWS.SNS.SMS.SMSType": {StringValue: aws.String("Transactional"), DataType: aws.String("String")}, + } + if a.senderID != "" { + // According with AWS, senderID is optional: https://docs.aws.amazon.com/sns/latest/dg/sms_publish-to-phone.html#sms_publish_sdk + messageAttributes["AWS.SNS.SMS.SenderID"] = &sns.MessageAttributeValue{StringValue: aws.String(a.senderID), DataType: aws.String("String")} + } + + params := &sns.PublishInput{ + PhoneNumber: aws.String(message.ToPhoneNumber), + Message: aws.String(message.Message), + MessageAttributes: messageAttributes, + } + + _, err = a.snsService.Publish(params) + if err != nil { + return fmt.Errorf("sending AWS SNS SMS: %w", err) + } + + log.Debugf("πŸŽ‰ AWS SNS sent an SMS to the phoneNumber %q", utils.TruncateString(message.ToPhoneNumber, 3)) + return nil +} + +// NewAWSSNSClient creates a new awsSNSClient, that is used to send SMS messages. +func NewAWSSNSClient(accessKeyID, secretAccessKey, region, senderID string) (*awsSNSClient, error) { + accessKeyID = strings.TrimSpace(accessKeyID) + if accessKeyID == "" { + return nil, fmt.Errorf("aws accessKeyID is empty") + } + + secretAccessKey = strings.TrimSpace(secretAccessKey) + if secretAccessKey == "" { + return nil, fmt.Errorf("aws secretAccessKey is empty") + } + + region = strings.TrimSpace(region) + if region == "" { + return nil, fmt.Errorf("aws region is empty") + } + + senderID = strings.TrimSpace(senderID) + + awsSession, err := session.NewSession(&aws.Config{ + Credentials: credentials.NewStaticCredentials(accessKeyID, secretAccessKey, ""), + Region: aws.String(region), + }) + if err != nil { + return nil, fmt.Errorf("creating AWS session: %w", err) + } + + return &awsSNSClient{ + senderID: senderID, + snsService: sns.New(awsSession), + }, nil +} + +var _ MessengerClient = (*awsSNSClient)(nil) diff --git a/internal/message/aws_sns_client_test.go b/internal/message/aws_sns_client_test.go new file mode 100644 index 000000000..dc6a79bce --- /dev/null +++ b/internal/message/aws_sns_client_test.go @@ -0,0 +1,110 @@ +package message + +import ( + "fmt" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockAWSSNSClient struct { + mock.Mock +} + +func (m *mockAWSSNSClient) Publish(input *sns.PublishInput) (*sns.PublishOutput, error) { + args := m.Called(input) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*sns.PublishOutput), args.Error(1) +} + +func Test_NewAWSSNSClient(t *testing.T) { + // Declare types in advance to make sure these are the types being returned + var gotAWSSNSClient *awsSNSClient + var err error + + // accessKeyID cannot be empty + gotAWSSNSClient, err = NewAWSSNSClient("", "", "", "") + require.Nil(t, gotAWSSNSClient) + require.EqualError(t, err, "aws accessKeyID is empty") + + // secretAccessKey cannot be empty + gotAWSSNSClient, err = NewAWSSNSClient("accessKeyID", "", "", "") + require.Nil(t, gotAWSSNSClient) + require.EqualError(t, err, "aws secretAccessKey is empty") + + // region cannot be empty + gotAWSSNSClient, err = NewAWSSNSClient("accessKeyID", "secretAccessKey", "", "") + require.Nil(t, gotAWSSNSClient) + require.EqualError(t, err, "aws region is empty") + + // [sms] type doesn't need a sender ID: + gotAWSSNSClient, err = NewAWSSNSClient("accessKeyID", "secretAccessKey", "region", " ") + require.NoError(t, err) + require.NotNil(t, gotAWSSNSClient) + + // [sms] all fields are present πŸŽ‰ + gotAWSSNSClient, err = NewAWSSNSClient("accessKeyID", "secretAccessKey", "region", "testSenderID") + require.NoError(t, err) + require.NotNil(t, gotAWSSNSClient) +} + +func Test_AWSSNS_SendMessage_messageIsInvalid(t *testing.T) { + var mAWS MessengerClient = &awsSNSClient{} + err := mAWS.SendMessage(Message{}) + require.EqualError(t, err, "validating message to send an SMS through AWS: invalid message: phone number cannot be empty") +} + +func Test_AWSSNS_SendMessage_errorIsHandledCorrectly(t *testing.T) { + // check if error is handled correctly + testPhoneNumber := "+14155555555" + testMessage := "foo bar" + testSenderID := "senderID" + mAWSSNS := mockAWSSNSClient{} + mAWSSNS. + On("Publish", &sns.PublishInput{ + PhoneNumber: aws.String(testPhoneNumber), + Message: aws.String(testMessage), + MessageAttributes: map[string]*sns.MessageAttributeValue{ + "AWS.SNS.SMS.SenderID": {StringValue: aws.String(testSenderID), DataType: aws.String("String")}, + "AWS.SNS.SMS.SMSType": {StringValue: aws.String("Transactional"), DataType: aws.String("String")}, + }, + }). + Return(nil, fmt.Errorf("test AWS SNS error")). + Once() + + mAWS := awsSNSClient{snsService: &mAWSSNS, senderID: "senderID"} + err := mAWS.SendMessage(Message{ToPhoneNumber: "+14155555555", Message: "foo bar"}) + require.EqualError(t, err, "sending AWS SNS SMS: test AWS SNS error") + + mAWSSNS.AssertExpectations(t) +} + +func Test_AWSSNS_SendMessage_success(t *testing.T) { + // check if error is handled correctly + testPhoneNumber := "+14152222222" + testMessage := "foo bar" + testSenderID := "senderID" + mAWSSNS := mockAWSSNSClient{} + mAWSSNS. + On("Publish", &sns.PublishInput{ + PhoneNumber: aws.String(testPhoneNumber), + Message: aws.String(testMessage), + MessageAttributes: map[string]*sns.MessageAttributeValue{ + "AWS.SNS.SMS.SenderID": {StringValue: aws.String(testSenderID), DataType: aws.String("String")}, + "AWS.SNS.SMS.SMSType": {StringValue: aws.String("Transactional"), DataType: aws.String("String")}, + }, + }). + Return(nil, nil). + Once() + + mAWS := awsSNSClient{snsService: &mAWSSNS, senderID: "senderID"} + err := mAWS.SendMessage(Message{ToPhoneNumber: "+14152222222", Message: "foo bar"}) + require.NoError(t, err) + + mAWSSNS.AssertExpectations(t) +} diff --git a/internal/message/dry_run_client.go b/internal/message/dry_run_client.go new file mode 100644 index 000000000..05509b06d --- /dev/null +++ b/internal/message/dry_run_client.go @@ -0,0 +1,31 @@ +package message + +import ( + "fmt" + "strings" +) + +type dryRunClient struct{} + +func (c *dryRunClient) SendMessage(message Message) error { + recipient := message.ToEmail + if message.ToEmail == "" { + recipient = message.ToPhoneNumber + } + + fmt.Println(strings.Repeat("-", 79)) + fmt.Println("Recipient:", recipient) + fmt.Println("Subject:", message.Title) + fmt.Println("Content:", message.Message) + fmt.Println(strings.Repeat("-", 79)) + + return nil +} + +func (c *dryRunClient) MessengerType() MessengerType { + return MessengerTypeDryRun +} + +func NewDryRunClient() (MessengerClient, error) { + return &dryRunClient{}, nil +} diff --git a/internal/message/dry_run_client_test.go b/internal/message/dry_run_client_test.go new file mode 100644 index 000000000..8e93dcd80 --- /dev/null +++ b/internal/message/dry_run_client_test.go @@ -0,0 +1,79 @@ +package message + +import ( + "io" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_DryRunClient(t *testing.T) { + cc, _ := NewDryRunClient() + + // Email + stdOut := os.Stdout + + r, w, err := os.Pipe() + require.NoError(t, err) + + os.Stdout = w + + msg := Message{ + ToPhoneNumber: "", + ToEmail: "email@email.com", + Title: "My Message Title", + Message: "My email content", + } + err = cc.SendMessage(msg) + require.NoError(t, err) + + w.Close() + os.Stdout = stdOut + + buf := new(strings.Builder) + _, err = io.Copy(buf, r) + require.NoError(t, err) + + expected := `------------------------------------------------------------------------------- +Recipient: email@email.com +Subject: My Message Title +Content: My email content +------------------------------------------------------------------------------- +` + assert.Equal(t, expected, buf.String()) + + // SMS + stdOut = os.Stdout + + r, w, err = os.Pipe() + require.NoError(t, err) + + os.Stdout = w + + msg = Message{ + ToPhoneNumber: "+11111111111", + ToEmail: "", + Title: "My Message Title", + Message: "My SMS content", + } + err = cc.SendMessage(msg) + require.NoError(t, err) + + w.Close() + os.Stdout = stdOut + + buf = new(strings.Builder) + _, err = io.Copy(buf, r) + require.NoError(t, err) + + expected = `------------------------------------------------------------------------------- +Recipient: +11111111111 +Subject: My Message Title +Content: My SMS content +------------------------------------------------------------------------------- +` + assert.Equal(t, expected, buf.String()) +} diff --git a/internal/message/main.go b/internal/message/main.go new file mode 100644 index 000000000..d34fd68c1 --- /dev/null +++ b/internal/message/main.go @@ -0,0 +1,91 @@ +package message + +import ( + "fmt" + "strings" + + "golang.org/x/exp/slices" +) + +type MessengerType string + +// ATTENTION: when adding a new type, make ure to update the MessengerType methods! +const ( + // MessengerTypeTwilioSMS is used to send SMS messages using Twilio. + MessengerTypeTwilioSMS MessengerType = "TWILIO_SMS" + // MessengerTypeAWSSMS is used to send SMS messages using AWS SNS. + MessengerTypeAWSSMS MessengerType = "AWS_SMS" + // MessengerTypeAWSEmail is used to send emails using AWS SES. + MessengerTypeAWSEmail MessengerType = "AWS_EMAIL" + // MessengerTypeDryRun is used for development environment + MessengerTypeDryRun MessengerType = "DRY_RUN" +) + +func (mt MessengerType) All() []MessengerType { + return []MessengerType{MessengerTypeTwilioSMS, MessengerTypeAWSSMS, MessengerTypeAWSEmail, MessengerTypeDryRun} +} + +func ParseMessengerType(messengerTypeStr string) (MessengerType, error) { + messageTypeStrUpper := strings.ToUpper(messengerTypeStr) + mType := MessengerType(messageTypeStrUpper) + + if slices.Contains(MessengerType("").All(), mType) { + return mType, nil + } + + return "", fmt.Errorf("invalid message sender type %q", messageTypeStrUpper) +} + +func (mt MessengerType) ValidSMSTypes() []MessengerType { + return []MessengerType{MessengerTypeDryRun, MessengerTypeTwilioSMS, MessengerTypeAWSSMS} +} + +func (mt MessengerType) ValidEmailTypes() []MessengerType { + return []MessengerType{MessengerTypeDryRun, MessengerTypeAWSEmail} +} + +func (mt MessengerType) IsSMS() bool { + return slices.Contains(mt.ValidSMSTypes(), mt) +} + +func (mt MessengerType) IsEmail() bool { + return slices.Contains(mt.ValidEmailTypes(), mt) +} + +type MessengerOptions struct { + MessengerType MessengerType + Environment string + + // Twilio + TwilioAccountSID string + TwilioAuthToken string + TwilioServiceSID string + + // AWS + AWSAccessKeyID string + AWSSecretAccessKey string + AWSRegion string + // AWS SNS (SMS messages) + AWSSNSSenderID string + // AWS SES (EMAIL messages) + AWSSESSenderID string +} + +func GetClient(opts MessengerOptions) (MessengerClient, error) { + switch opts.MessengerType { + case MessengerTypeTwilioSMS: + return NewTwilioClient(opts.TwilioAccountSID, opts.TwilioAuthToken, opts.TwilioServiceSID) + + case MessengerTypeAWSSMS: + return NewAWSSNSClient(opts.AWSAccessKeyID, opts.AWSSecretAccessKey, opts.AWSRegion, opts.AWSSNSSenderID) + + case MessengerTypeAWSEmail: + return NewAWSSESClient(opts.AWSAccessKeyID, opts.AWSSecretAccessKey, opts.AWSRegion, opts.AWSSESSenderID) + + case MessengerTypeDryRun: + return NewDryRunClient() + + default: + return nil, fmt.Errorf("unknown message sender type: %q", opts.MessengerType) + } +} diff --git a/internal/message/main_test.go b/internal/message/main_test.go new file mode 100644 index 000000000..48a68fd81 --- /dev/null +++ b/internal/message/main_test.go @@ -0,0 +1,81 @@ +package message + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParseMessengerType(t *testing.T) { + testCases := []struct { + messengerType string + wantErr error + }{ + {wantErr: fmt.Errorf("invalid message sender type \"\"")}, + {messengerType: "foo_BAR", wantErr: fmt.Errorf("invalid message sender type \"FOO_BAR\"")}, + {messengerType: "TWILIO_SMS"}, + {messengerType: "tWiLiO_SMS"}, + {messengerType: "AWS_SMS"}, + {messengerType: "AWS_EMAIL"}, + {messengerType: "DRY_RUN"}, + } + + for _, tc := range testCases { + t.Run("messengerType: "+tc.messengerType, func(t *testing.T) { + _, err := ParseMessengerType(tc.messengerType) + if tc.wantErr != nil { + assert.Equal(t, tc.wantErr, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_GetClient(t *testing.T) { + // MessengerTypeTwilioSMS + messengerType := MessengerTypeTwilioSMS + opts := MessengerOptions{ + MessengerType: messengerType, + TwilioAccountSID: "accountSid", + TwilioAuthToken: "authToken", + TwilioServiceSID: "senderID", + } + gotClient, err := GetClient(opts) + require.NoError(t, err) + require.IsType(t, &twilioClient{}, gotClient) + + // MessengerTypeAWSSMS + messengerType = MessengerTypeAWSSMS + opts = MessengerOptions{ + MessengerType: messengerType, + AWSAccessKeyID: "accessKeyID", + AWSSecretAccessKey: "secretAccessKey", + AWSRegion: "region", + AWSSNSSenderID: "mySenderID", + } + gotClient, err = GetClient(opts) + require.NoError(t, err) + require.IsType(t, &awsSNSClient{}, gotClient) + gotAWSSNSClient, ok := gotClient.(*awsSNSClient) + require.True(t, ok) + require.NotNil(t, gotAWSSNSClient.snsService) + + // MessengerTypeAWSEmail + messengerType = MessengerTypeAWSEmail + opts = MessengerOptions{ + MessengerType: messengerType, + AWSAccessKeyID: "accessKeyID", + AWSSecretAccessKey: "secretAccessKey", + AWSRegion: "region", + AWSSESSenderID: "foo@test.com", + } + gotClient, err = GetClient(opts) + require.NoError(t, err) + require.IsType(t, &awsSESClient{}, gotClient) + gotAWSSESClient, ok := gotClient.(*awsSESClient) + require.True(t, ok) + require.NotNil(t, gotAWSSESClient.emailService) +} diff --git a/internal/message/message.go b/internal/message/message.go new file mode 100644 index 000000000..265d2f040 --- /dev/null +++ b/internal/message/message.go @@ -0,0 +1,40 @@ +package message + +import ( + "fmt" + "strings" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type Message struct { + ToPhoneNumber string + ToEmail string + Message string + Title string +} + +// ValidateFor validates if the message object is valid for the given messengerType. +func (s *Message) ValidateFor(messengerType MessengerType) error { + if messengerType.IsSMS() { + if err := utils.ValidatePhoneNumber(s.ToPhoneNumber); err != nil { + return fmt.Errorf("invalid message: %w", err) + } + } + + if messengerType.IsEmail() { + if err := utils.ValidateEmail(s.ToEmail); err != nil { + return fmt.Errorf("invalid message: %w", err) + } + + if strings.Trim(s.Title, " ") == "" { + return fmt.Errorf("title is empty") + } + } + + if strings.Trim(s.Message, " ") == "" { + return fmt.Errorf("message is empty") + } + + return nil +} diff --git a/internal/message/message_test.go b/internal/message/message_test.go new file mode 100644 index 000000000..42ea6633c --- /dev/null +++ b/internal/message/message_test.go @@ -0,0 +1,91 @@ +package message + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_message_Validate(t *testing.T) { + testCases := []struct { + name string + messengerType MessengerType + message Message + wantErr error + }{ + // SMS types + { + name: "SMS types need a non-empty phone number", + messengerType: MessengerTypeTwilioSMS, + message: Message{}, + wantErr: fmt.Errorf("invalid message: phone number cannot be empty"), + }, + { + name: "SMS types need a valid phone number", + messengerType: MessengerTypeTwilioSMS, + message: Message{ToPhoneNumber: "invalid-phone"}, + wantErr: fmt.Errorf("invalid message: the provided phone number is not a valid E.164 number"), + }, + { + name: "[sms] message cannot be empty", + messengerType: MessengerTypeTwilioSMS, + message: Message{ToPhoneNumber: "+14152111111", Message: " "}, + wantErr: fmt.Errorf("message is empty"), + }, + { + name: "[sms] all fields are present for Twilio πŸŽ‰", + messengerType: MessengerTypeTwilioSMS, + message: Message{ToPhoneNumber: "+14152111111", Message: "foo bar"}, + wantErr: nil, + }, + { + name: "[sms] all fields are present for AWS SNS πŸŽ‰", + messengerType: MessengerTypeAWSSMS, + message: Message{ToPhoneNumber: "+14152111111", Message: "foo bar"}, + wantErr: nil, + }, + // Email types + { + name: "Email types need a non-empty email address", + messengerType: MessengerTypeAWSEmail, + message: Message{}, + wantErr: fmt.Errorf("invalid message: email cannot be empty"), + }, + { + name: "Email types need a valid email address", + messengerType: MessengerTypeAWSEmail, + message: Message{ToEmail: "invalid-email"}, + wantErr: fmt.Errorf("invalid message: the provided email is not valid"), + }, + { + name: "Email types need a title", + messengerType: MessengerTypeAWSEmail, + message: Message{ToEmail: "foo@test.com", Title: " "}, + wantErr: fmt.Errorf("title is empty"), + }, + { + name: "[sms] message cannot be empty", + messengerType: MessengerTypeAWSEmail, + message: Message{ToEmail: "foo@test.com", Title: "My title"}, + wantErr: fmt.Errorf("message is empty"), + }, + { + name: "[email] all fields are present for AWS email πŸŽ‰", + messengerType: MessengerTypeAWSEmail, + message: Message{ToEmail: "foo@test.com", Title: "My title", Message: "foo bar"}, + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.message.ValidateFor(tc.messengerType) + if tc.wantErr != nil { + require.EqualError(t, err, tc.wantErr.Error()) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/message/messenger_client.go b/internal/message/messenger_client.go new file mode 100644 index 000000000..913bedf1a --- /dev/null +++ b/internal/message/messenger_client.go @@ -0,0 +1,6 @@ +package message + +type MessengerClient interface { + SendMessage(message Message) error + MessengerType() MessengerType +} diff --git a/internal/message/mocks.go b/internal/message/mocks.go new file mode 100644 index 000000000..cabd6b72f --- /dev/null +++ b/internal/message/mocks.go @@ -0,0 +1,21 @@ +package message + +import ( + "github.com/stretchr/testify/mock" +) + +type MessengerClientMock struct { + mock.Mock +} + +func (mc *MessengerClientMock) SendMessage(message Message) error { + args := mc.Called(message) + return args.Error(0) +} + +func (mc *MessengerClientMock) MessengerType() MessengerType { + args := mc.Called() + return args.Get(0).(MessengerType) +} + +var _ MessengerClient = (*MessengerClientMock)(nil) diff --git a/internal/message/twilio_client.go b/internal/message/twilio_client.go new file mode 100644 index 000000000..1d11602a9 --- /dev/null +++ b/internal/message/twilio_client.go @@ -0,0 +1,88 @@ +package message + +import ( + "fmt" + "strings" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/twilio/twilio-go" + twilioApi "github.com/twilio/twilio-go/rest/api/v2010" +) + +type twilioApiInterface interface { + CreateMessage(params *twilioApi.CreateMessageParams) (*twilioApi.ApiV2010Message, error) +} + +type twilioClient struct { + apiService twilioApiInterface + senderID string +} + +func (t *twilioClient) MessengerType() MessengerType { + return MessengerTypeTwilioSMS +} + +func (t *twilioClient) CreateMessage(params *twilioApi.CreateMessageParams) (*twilioApi.ApiV2010Message, error) { + return t.apiService.CreateMessage(params) +} + +func (t *twilioClient) SendMessage(message Message) error { + err := message.ValidateFor(t.MessengerType()) + if err != nil { + return fmt.Errorf("validating SMS message: %w", err) + } + + resp, err := t.CreateMessage(&twilioApi.CreateMessageParams{ + To: &message.ToPhoneNumber, + Body: &message.Message, + MessagingServiceSid: &t.senderID, + }) + if err != nil { + return fmt.Errorf("sending Twilio SMS: %w", err) + } + + if resp.ErrorCode != nil || resp.ErrorMessage != nil { + var errorCode string + if resp.ErrorCode != nil { + errorCode = fmt.Sprintf("%d", *resp.ErrorCode) + } + + var errorMessage string + if resp.ErrorMessage != nil { + errorMessage = *resp.ErrorMessage + } + + return fmt.Errorf("sending Twilio SMS responded an error {code: %q, message: %q}", errorCode, errorMessage) + } + + log.Debugf("Twilio sent an SMS to the phoneNumber %q", utils.TruncateString(message.ToPhoneNumber, 3)) + return nil +} + +func NewTwilioClient(accountSid, authToken, senderID string) (*twilioClient, error) { + accountSid = strings.TrimSpace(accountSid) + if accountSid == "" { + return nil, fmt.Errorf("twilio accountSid is empty") + } + + authToken = strings.TrimSpace(authToken) + if authToken == "" { + return nil, fmt.Errorf("twilio authToken is empty") + } + + senderID = strings.TrimSpace(senderID) + if senderID == "" { + return nil, fmt.Errorf("twilio senderID is empty") + } + + return &twilioClient{ + apiService: twilio.NewRestClientWithParams(twilio.ClientParams{ + Username: accountSid, + Password: authToken, + }).Api, + senderID: senderID, + }, nil +} + +var _ MessengerClient = (*twilioClient)(nil) diff --git a/internal/message/twilio_client_test.go b/internal/message/twilio_client_test.go new file mode 100644 index 000000000..7651c1696 --- /dev/null +++ b/internal/message/twilio_client_test.go @@ -0,0 +1,139 @@ +package message + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/twilio/twilio-go" + twilioAPI "github.com/twilio/twilio-go/rest/api/v2010" +) + +type mockTwilioApi struct { + mock.Mock +} + +func (m *mockTwilioApi) CreateMessage(params *twilioAPI.CreateMessageParams) (*twilioAPI.ApiV2010Message, error) { + args := m.Called(params) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*twilioAPI.ApiV2010Message), args.Error(1) +} + +func Test_NewTwilioClient(t *testing.T) { + // Declare types in advance to make sure these are the types being returned + var gotTwilioClient MessengerClient + var err error + + // accountSid cannot be empty + gotTwilioClient, err = NewTwilioClient("", "", "") + require.Nil(t, gotTwilioClient) + require.EqualError(t, err, "twilio accountSid is empty") + + // accountSid cannot be empty + gotTwilioClient, err = NewTwilioClient("accountSid", " ", "") + require.Nil(t, gotTwilioClient) + require.EqualError(t, err, "twilio authToken is empty") + + // senderID cannot be empty + gotTwilioClient, err = NewTwilioClient("accountSid", "authToken", "") + require.Nil(t, gotTwilioClient) + require.EqualError(t, err, "twilio senderID is empty") + + // all fields are present πŸŽ‰ + gotTwilioClient, err = NewTwilioClient("accountSid", "authToken", "senderID") + require.NoError(t, err) + wantTwilioClient := &twilioClient{ + apiService: twilio.NewRestClientWithParams(twilio.ClientParams{ + Username: "accountSid", + Password: "authToken", + }).Api, + senderID: "senderID", + } + require.Equal(t, wantTwilioClient, gotTwilioClient) +} + +func Test_Twilio_messengerType(t *testing.T) { + tw := twilioClient{} + require.Equal(t, MessengerTypeTwilioSMS, tw.MessengerType()) +} + +func Test_Twilio_SendMessage_messageIsInvalid(t *testing.T) { + var mTwilio MessengerClient = &twilioClient{} + err := mTwilio.SendMessage(Message{}) + require.EqualError(t, err, "validating SMS message: invalid message: phone number cannot be empty") +} + +func Test_Twilio_SendMessage_errorIsHandledCorrectly(t *testing.T) { + // check if error is handled correctly + testPhoneNumber := "+14155111111" + testMessage := "foo bar" + testSenderID := "senderID" + mTwilioApi := mockTwilioApi{} + mTwilioApi. + On("CreateMessage", &twilioAPI.CreateMessageParams{ + To: &testPhoneNumber, + Body: &testMessage, + MessagingServiceSid: &testSenderID, + }). + Return(nil, fmt.Errorf("test twilio error")). + Once() + + mTwilio := twilioClient{apiService: &mTwilioApi, senderID: "senderID"} + err := mTwilio.SendMessage(Message{ToPhoneNumber: "+14155111111", Message: "foo bar"}) + require.EqualError(t, err, "sending Twilio SMS: test twilio error") + + mTwilioApi.AssertExpectations(t) +} + +func Test_Twilio_SendMessage_doesntReturnErrorButResponseContainsErrorEmbedded(t *testing.T) { + // validate the case where the response contains an error message, + // despite the method succeeding + testPhoneNumber2 := "+14152222222" + testMessage2 := "foo bar" + testSenderID := "senderID" + + wantErrCode := 12345 + wantErrMessage := "Foo bar error message" + + mTwilioApi := mockTwilioApi{} + mTwilioApi. + On("CreateMessage", &twilioAPI.CreateMessageParams{ + To: &testPhoneNumber2, + Body: &testMessage2, + MessagingServiceSid: &testSenderID, + }). + Return(&twilioAPI.ApiV2010Message{ + ErrorCode: &wantErrCode, + ErrorMessage: &wantErrMessage, + }, nil). + Once() + + mTwilio := twilioClient{apiService: &mTwilioApi, senderID: "senderID"} + err := mTwilio.SendMessage(Message{ToPhoneNumber: "+14152222222", Message: "foo bar"}) + require.EqualError(t, err, `sending Twilio SMS responded an error {code: "12345", message: "Foo bar error message"}`) +} + +func Test_Twilio_SendMessage_success(t *testing.T) { + // check if error is handled correctly + testPhoneNumber := "+14153333333" + testMessage := "foo bar" + testSenderID := "senderID" + mTwilioApi := mockTwilioApi{} + mTwilioApi. + On("CreateMessage", &twilioAPI.CreateMessageParams{ + To: &testPhoneNumber, + Body: &testMessage, + MessagingServiceSid: &testSenderID, + }). + Return(&twilioAPI.ApiV2010Message{}, nil). + Once() + + mTwilio := twilioClient{apiService: &mTwilioApi, senderID: "senderID"} + err := mTwilio.SendMessage(Message{ToPhoneNumber: "+14153333333", Message: "foo bar"}) + require.NoError(t, err) + + mTwilioApi.AssertExpectations(t) +} diff --git a/internal/monitor/main.go b/internal/monitor/main.go new file mode 100644 index 000000000..2df87c145 --- /dev/null +++ b/internal/monitor/main.go @@ -0,0 +1,43 @@ +package monitor + +import ( + "fmt" + "strings" +) + +type MetricType string + +const ( + MetricTypePrometheus MetricType = "PROMETHEUS" + MetricTypeTSSPrometheus MetricType = "TSS_PROMETHEUS" +) + +func ParseMetricType(metricTypeStr string) (MetricType, error) { + metricTypeStrUpper := strings.ToUpper(metricTypeStr) + mType := MetricType(metricTypeStrUpper) + + switch mType { + case MetricTypePrometheus: + return mType, nil + case MetricTypeTSSPrometheus: + return mType, nil + default: + return "", fmt.Errorf("invalid metric type %q", metricTypeStrUpper) + } +} + +type MetricOptions struct { + MetricType MetricType + Environment string +} + +func GetClient(opts MetricOptions) (MonitorClient, error) { + switch opts.MetricType { + case MetricTypePrometheus: + return NewPrometheusClient() + case MetricTypeTSSPrometheus: + return NewTSSPrometheusClient() + default: + return nil, fmt.Errorf("unknown metric type: %q", opts.MetricType) + } +} diff --git a/internal/monitor/main_test.go b/internal/monitor/main_test.go new file mode 100644 index 000000000..9b850e3f8 --- /dev/null +++ b/internal/monitor/main_test.go @@ -0,0 +1,48 @@ +package monitor + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ParseMetricType(t *testing.T) { + testCases := []struct { + metricTypeStr string + expectedMetricType MetricType + wantErr error + }{ + {wantErr: fmt.Errorf("invalid metric type \"\"")}, + {metricTypeStr: "MOCKMETRICTYPE", wantErr: fmt.Errorf("invalid metric type \"MOCKMETRICTYPE\"")}, + {metricTypeStr: "prometheus", expectedMetricType: MetricTypePrometheus}, + {metricTypeStr: "PromeTHEUS", expectedMetricType: MetricTypePrometheus}, + } + for _, tc := range testCases { + t.Run("metricType: "+tc.metricTypeStr, func(t *testing.T) { + metricType, err := ParseMetricType(tc.metricTypeStr) + assert.Equal(t, tc.expectedMetricType, metricType) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func Test_GetClient(t *testing.T) { + metricOptions := MetricOptions{} + + t.Run("get prometheus monitor client", func(t *testing.T) { + metricOptions.MetricType = MetricTypePrometheus + + gotClient, err := GetClient(metricOptions) + assert.NoError(t, err) + assert.IsType(t, &prometheusClient{}, gotClient) + }) + + t.Run("error metric passed is invalid", func(t *testing.T) { + metricOptions.MetricType = "MOCKMETRICTYPE" + + gotClient, err := GetClient(metricOptions) + assert.Nil(t, gotClient) + assert.EqualError(t, err, "unknown metric type: \"MOCKMETRICTYPE\"") + }) +} diff --git a/internal/monitor/metric_tags.go b/internal/monitor/metric_tags.go new file mode 100644 index 000000000..9a78ce62e --- /dev/null +++ b/internal/monitor/metric_tags.go @@ -0,0 +1,19 @@ +package monitor + +type MetricTag string + +const ( + SuccessfulQueryDurationTag MetricTag = "successful_queries_duration" + FailureQueryDurationTag MetricTag = "failure_queries_duration" + HttpRequestDurationTag MetricTag = "requests_duration_seconds" + DisbursementsCounterTag MetricTag = "disbursements_counter" +) + +func (m MetricTag) ListAll() []MetricTag { + return []MetricTag{ + SuccessfulQueryDurationTag, + FailureQueryDurationTag, + HttpRequestDurationTag, + DisbursementsCounterTag, + } +} diff --git a/internal/monitor/monitor_client.go b/internal/monitor/monitor_client.go new file mode 100644 index 000000000..894d4f897 --- /dev/null +++ b/internal/monitor/monitor_client.go @@ -0,0 +1,16 @@ +package monitor + +import ( + "net/http" + "time" +) + +type MonitorClient interface { + GetMetricHttpHandler() http.Handler + GetMetricType() MetricType + MonitorHttpRequestDuration(duration time.Duration, labels HttpRequestLabels) + MonitorDBQueryDuration(duration time.Duration, tag MetricTag, labels DBQueryLabels) + MonitorCounters(tag MetricTag, labels map[string]string) + MonitorDuration(duration time.Duration, tag MetricTag, labels map[string]string) + MonitorHistogram(value float64, tag MetricTag, labels map[string]string) +} diff --git a/internal/monitor/monitor_labels.go b/internal/monitor/monitor_labels.go new file mode 100644 index 000000000..32f7c1f3d --- /dev/null +++ b/internal/monitor/monitor_labels.go @@ -0,0 +1,25 @@ +package monitor + +type HttpRequestLabels struct { + Status string + Route string + Method string +} + +type DBQueryLabels struct { + QueryType string +} + +type DisbursementLabels struct { + Asset string + Country string + Wallet string +} + +func (d DisbursementLabels) ToMap() map[string]string { + return map[string]string{ + "asset": d.Asset, + "country": d.Country, + "wallet": d.Wallet, + } +} diff --git a/internal/monitor/monitor_service.go b/internal/monitor/monitor_service.go new file mode 100644 index 000000000..077975679 --- /dev/null +++ b/internal/monitor/monitor_service.go @@ -0,0 +1,103 @@ +package monitor + +import ( + "fmt" + "net/http" + "time" +) + +type MonitorServiceInterface interface { + Start(opts MetricOptions) error + GetMetricType() (MetricType, error) + GetMetricHttpHandler() (http.Handler, error) + MonitorHttpRequestDuration(duration time.Duration, labels HttpRequestLabels) error + MonitorDBQueryDuration(duration time.Duration, tag MetricTag, labels DBQueryLabels) error + MonitorCounters(tag MetricTag, labels map[string]string) error + MonitorDuration(duration time.Duration, tag MetricTag, labels map[string]string) error + MonitorHistogram(value float64, tag MetricTag, labels map[string]string) error +} + +type MonitorService struct { + monitorClient MonitorClient +} + +func (m *MonitorService) Start(opts MetricOptions) error { + if m.monitorClient != nil { + return fmt.Errorf("service already initialized") + } + + monitorClient, err := GetClient(opts) + if err != nil { + return fmt.Errorf("error creating monitor client: %w", err) + } + + m.monitorClient = monitorClient + + return nil +} + +func (m *MonitorService) GetMetricType() (MetricType, error) { + if m.monitorClient == nil { + return "", fmt.Errorf("client was not initialized") + } + + return m.monitorClient.GetMetricType(), nil +} + +func (m *MonitorService) GetMetricHttpHandler() (http.Handler, error) { + if m.monitorClient == nil { + return nil, fmt.Errorf("client was not initialized") + } + + return m.monitorClient.GetMetricHttpHandler(), nil +} + +func (m *MonitorService) MonitorHttpRequestDuration(duration time.Duration, labels HttpRequestLabels) error { + if m.monitorClient == nil { + return fmt.Errorf("client was not initialized") + } + + m.monitorClient.MonitorHttpRequestDuration(duration, labels) + + return nil +} + +func (m *MonitorService) MonitorDBQueryDuration(duration time.Duration, tag MetricTag, labels DBQueryLabels) error { + if m.monitorClient == nil { + return fmt.Errorf("client was not initialized") + } + + m.monitorClient.MonitorDBQueryDuration(duration, tag, labels) + + return nil +} + +func (m *MonitorService) MonitorDuration(duration time.Duration, tag MetricTag, labels map[string]string) error { + if m.monitorClient == nil { + return fmt.Errorf("client was not initialized") + } + + m.monitorClient.MonitorDuration(duration, tag, labels) + + return nil +} + +func (m *MonitorService) MonitorHistogram(value float64, tag MetricTag, labels map[string]string) error { + if m.monitorClient == nil { + return fmt.Errorf("client was not initialized") + } + + m.monitorClient.MonitorHistogram(value, tag, labels) + + return nil +} + +func (m *MonitorService) MonitorCounters(tag MetricTag, labels map[string]string) error { + if m.monitorClient == nil { + return fmt.Errorf("client was not initialized") + } + + m.monitorClient.MonitorCounters(tag, labels) + + return nil +} diff --git a/internal/monitor/monitor_service_mocks.go b/internal/monitor/monitor_service_mocks.go new file mode 100644 index 000000000..4257ed4a4 --- /dev/null +++ b/internal/monitor/monitor_service_mocks.go @@ -0,0 +1,52 @@ +package monitor + +import ( + "net/http" + "time" + + "github.com/stretchr/testify/mock" +) + +type MockMonitorService struct { + mock.Mock +} + +func (m *MockMonitorService) GetMetricHttpHandler() (http.Handler, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(http.Handler), args.Error(1) +} + +func (m *MockMonitorService) GetMetricType() (MetricType, error) { + args := m.Called() + if args.Get(0) == nil { + return "", args.Error(1) + } + return args.Get(0).(MetricType), args.Error(1) +} + +func (m *MockMonitorService) MonitorHttpRequestDuration(duration time.Duration, labels HttpRequestLabels) error { + return m.Called(duration, labels).Error(0) +} + +func (m *MockMonitorService) MonitorDBQueryDuration(duration time.Duration, tag MetricTag, labels DBQueryLabels) error { + return m.Called(duration, tag, labels).Error(0) +} + +func (m *MockMonitorService) MonitorCounters(tag MetricTag, labels map[string]string) error { + return m.Called(tag, labels).Error(0) +} + +func (m *MockMonitorService) MonitorDuration(duration time.Duration, tag MetricTag, labels map[string]string) error { + return m.Called(duration, tag, labels).Error(0) +} + +func (m *MockMonitorService) MonitorHistogram(value float64, tag MetricTag, labels map[string]string) error { + return m.Called(value, tag, labels).Error(0) +} + +func (m *MockMonitorService) Start(opts MetricOptions) error { + return m.Called(opts).Error(0) +} diff --git a/internal/monitor/monitor_service_test.go b/internal/monitor/monitor_service_test.go new file mode 100644 index 000000000..ddbe8684f --- /dev/null +++ b/internal/monitor/monitor_service_test.go @@ -0,0 +1,233 @@ +package monitor + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockMonitorClient struct { + mock.Mock +} + +func (m *mockMonitorClient) GetMetricHttpHandler() http.Handler { + return m.Called().Get(0).(http.Handler) +} + +func (m *mockMonitorClient) GetMetricType() MetricType { + return m.Called().Get(0).(MetricType) +} + +func (m *mockMonitorClient) MonitorHttpRequestDuration(duration time.Duration, labels HttpRequestLabels) { + m.Called(duration, labels) +} + +func (m *mockMonitorClient) MonitorDBQueryDuration(duration time.Duration, tag MetricTag, labels DBQueryLabels) { + m.Called(duration, tag, labels) +} + +func (m *mockMonitorClient) MonitorCounters(tag MetricTag, labels map[string]string) { + m.Called(tag, labels) +} + +func (m *mockMonitorClient) MonitorDuration(duration time.Duration, tag MetricTag, labels map[string]string) { + m.Called(duration, tag, labels) +} + +func (m *mockMonitorClient) MonitorHistogram(value float64, tag MetricTag, labels map[string]string) { + m.Called(value, tag, labels) +} + +func Test_MetricsService_Start(t *testing.T) { + monitorService := &MonitorService{} + metricOptions := MetricOptions{} + + t.Run("start prometheus service metric", func(t *testing.T) { + metricOptions.MetricType = "PROMETHEUS" + err := monitorService.Start(metricOptions) + require.NoError(t, err) + + require.IsType(t, &prometheusClient{}, monitorService.monitorClient) + assert.NotNil(t, monitorService.monitorClient) + }) + + t.Run("error monitor service already initialized", func(t *testing.T) { + metricOptions.MetricType = "MOCK_METRIC_TYPE" + + err := monitorService.Start(metricOptions) + require.EqualError(t, err, "service already initialized") + }) + + t.Run("error unknown metric type", func(t *testing.T) { + monitorService.monitorClient = nil + + metricOptions.MetricType = "MOCK_METRIC_TYPE" + err := monitorService.Start(metricOptions) + require.EqualError(t, err, "error creating monitor client: unknown metric type: \"MOCK_METRIC_TYPE\"") + }) +} + +func Test_MetricsService_GetMetricHttpHandler(t *testing.T) { + monitorService := &MonitorService{} + + mMonitorClient := &mockMonitorClient{} + monitorService.monitorClient = mMonitorClient + + t.Run("running HttpServe with metric http handler", func(t *testing.T) { + mHttpHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"status": "OK"}`)) + require.NoError(t, err) + }) + mMonitorClient.On("GetMetricHttpHandler").Return(mHttpHandler).Once() + + httpHandler, err := monitorService.GetMetricHttpHandler() + require.NoError(t, err) + + r := chi.NewRouter() + r.Get("/metrics", httpHandler.ServeHTTP) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + wantJson := `{"status": "OK"}` + assert.JSONEq(t, wantJson, rr.Body.String()) + mMonitorClient.AssertExpectations(t) + }) + + t.Run("error monitor client not initialized", func(t *testing.T) { + monitorService.monitorClient = nil + + _, err := monitorService.GetMetricHttpHandler() + require.EqualError(t, err, "client was not initialized") + }) +} + +func Test_MetricsService_GetMetricType(t *testing.T) { + monitorService := &MonitorService{} + + mMonitorClient := &mockMonitorClient{} + monitorService.monitorClient = mMonitorClient + + t.Run("returns metric type", func(t *testing.T) { + mMonitorClient.On("GetMetricType").Return(MetricType("MOCKMETRICTYPE")).Once() + + metricType, err := monitorService.GetMetricType() + require.NoError(t, err) + + assert.Equal(t, MetricType("MOCKMETRICTYPE"), metricType) + mMonitorClient.AssertExpectations(t) + }) + + t.Run("error monitor client not initialized", func(t *testing.T) { + monitorService.monitorClient = nil + + _, err := monitorService.GetMetricType() + require.EqualError(t, err, "client was not initialized") + }) +} + +func Test_MetricsService_MonitorRequestTime(t *testing.T) { + monitorService := &MonitorService{} + + mMonitorClient := &mockMonitorClient{} + monitorService.monitorClient = mMonitorClient + + mLabels := HttpRequestLabels{ + Status: "200", + Route: "/mock", + Method: "get", + } + + mDuration := time.Duration(1) + + t.Run("monitor request time is called", func(t *testing.T) { + mMonitorClient.On("MonitorHttpRequestDuration", mDuration, mLabels).Once() + err := monitorService.MonitorHttpRequestDuration(mDuration, mLabels) + + require.NoError(t, err) + mMonitorClient.AssertExpectations(t) + }) + + t.Run("error monitor client not initialized", func(t *testing.T) { + monitorService.monitorClient = nil + + err := monitorService.MonitorHttpRequestDuration(mDuration, mLabels) + require.EqualError(t, err, "client was not initialized") + }) +} + +func Test_MetricsService_MonitorDBQueryDuration(t *testing.T) { + monitorService := &MonitorService{} + + mMonitorClient := &mockMonitorClient{} + monitorService.monitorClient = mMonitorClient + + mLabels := DBQueryLabels{ + QueryType: "SELECT", + } + + mDuration := time.Duration(1) + + mMetricTag := MetricTag("mock") + + t.Run("monitor db query duration is called", func(t *testing.T) { + mMonitorClient.On("MonitorDBQueryDuration", mDuration, mMetricTag, mLabels).Once() + err := monitorService.MonitorDBQueryDuration(mDuration, mMetricTag, mLabels) + + require.NoError(t, err) + mMonitorClient.AssertExpectations(t) + }) + + t.Run("error monitor client not initialized", func(t *testing.T) { + monitorService.monitorClient = nil + + err := monitorService.MonitorDBQueryDuration(mDuration, mMetricTag, mLabels) + require.EqualError(t, err, "client was not initialized") + }) +} + +func Test_MetricsService_MonitorCounter(t *testing.T) { + monitorService := &MonitorService{} + + mMonitorClient := &mockMonitorClient{} + monitorService.monitorClient = mMonitorClient + + mMetricTag := MetricTag("mock") + + t.Run("monitor counter is called without labels", func(t *testing.T) { + mMonitorClient.On("MonitorCounters", mMetricTag, map[string]string{}).Once() + err := monitorService.MonitorCounters(mMetricTag, map[string]string{}) + + require.NoError(t, err) + mMonitorClient.AssertExpectations(t) + }) + + t.Run("monitor counter is called with labels", func(t *testing.T) { + labelsMock := map[string]string{ + "mock": "mock_value", + } + + mMonitorClient.On("MonitorCounters", mMetricTag, labelsMock).Once() + err := monitorService.MonitorCounters(mMetricTag, labelsMock) + + require.NoError(t, err) + mMonitorClient.AssertExpectations(t) + }) + + t.Run("error monitor client not initialized", func(t *testing.T) { + monitorService.monitorClient = nil + + err := monitorService.MonitorCounters(mMetricTag, nil) + require.EqualError(t, err, "client was not initialized") + }) +} diff --git a/internal/monitor/prometheus_client.go b/internal/monitor/prometheus_client.go new file mode 100644 index 000000000..b6ed91b93 --- /dev/null +++ b/internal/monitor/prometheus_client.go @@ -0,0 +1,87 @@ +package monitor + +import ( + "fmt" + "net/http" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/stellar/go/support/log" +) + +type prometheusClient struct { + httpHandler http.Handler +} + +func (prometheusClient) GetMetricType() MetricType { + return MetricTypePrometheus +} + +func (p *prometheusClient) GetMetricHttpHandler() http.Handler { + return p.httpHandler +} + +func (p *prometheusClient) MonitorHttpRequestDuration(duration time.Duration, labels HttpRequestLabels) { + SummaryVecMetrics[HttpRequestDurationTag].With(prometheus.Labels{ + "status": labels.Status, + "route": labels.Route, + "method": labels.Method, + }).Observe(duration.Seconds()) +} + +func (p *prometheusClient) MonitorDBQueryDuration(duration time.Duration, tag MetricTag, labels DBQueryLabels) { + summary := SummaryVecMetrics[tag] + summary.With(prometheus.Labels{ + "query_type": labels.QueryType, + }).Observe(duration.Seconds()) +} + +func (p *prometheusClient) MonitorDuration(duration time.Duration, tag MetricTag, labels map[string]string) { + summary := SummaryVecMetrics[tag] + summary.With(labels).Observe(duration.Seconds()) +} + +func (p *prometheusClient) MonitorCounters(tag MetricTag, labels map[string]string) { + if len(labels) != 0 { + if counterVecMetric, ok := CounterVecMetrics[tag]; ok { + counterVecMetric.With(labels).Inc() + } else { + log.Errorf("metric not registered in prometheus metrics: %s", tag) + } + } else { + if counterMetric, ok := CounterMetrics[tag]; ok { + counterMetric.Inc() + } else { + log.Errorf("metric not registered in prometheus metrics: %s", tag) + } + } +} + +func (p *prometheusClient) MonitorHistogram(value float64, tag MetricTag, labels map[string]string) { + histogram := HistogramVecMetrics[tag] + histogram.With(labels).Observe(value) +} + +func NewPrometheusClient() (*prometheusClient, error) { + // register Prometheus metrics + metricsRegistry := prometheus.NewRegistry() + + var metricTag MetricTag + for _, tag := range metricTag.ListAll() { + if summaryVecMetric, ok := SummaryVecMetrics[tag]; ok { + metricsRegistry.MustRegister(summaryVecMetric) + } else if counterMetric, ok := CounterMetrics[tag]; ok { + metricsRegistry.MustRegister(counterMetric) + } else if counterVecMetric, ok := CounterVecMetrics[tag]; ok { + metricsRegistry.MustRegister(counterVecMetric) + } else { + return nil, fmt.Errorf("metric not registered in prometheus metrics: %s", tag) + } + } + + return &prometheusClient{httpHandler: promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{})}, nil +} + +// Ensuring that promtheusClient is implementing MonitorClient interface +var _ MonitorClient = (*prometheusClient)(nil) diff --git a/internal/monitor/prometheus_client_test.go b/internal/monitor/prometheus_client_test.go new file mode 100644 index 000000000..7e430849d --- /dev/null +++ b/internal/monitor/prometheus_client_test.go @@ -0,0 +1,243 @@ +package monitor + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/stellar/go/support/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PrometheusClient_GetMetricType(t *testing.T) { + mPrometheusClient := &prometheusClient{} + + metricType := mPrometheusClient.GetMetricType() + assert.Equal(t, MetricTypePrometheus, metricType) +} + +func Test_PrometheusClient_GetMetricHttpHandler(t *testing.T) { + mPrometheusClient := &prometheusClient{} + + mHttpHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"status": "OK"}`)) + require.NoError(t, err) + }) + + mPrometheusClient.httpHandler = mHttpHandler + + httpHandler := mPrometheusClient.GetMetricHttpHandler() + + r := chi.NewRouter() + r.Get("/metrics", httpHandler.ServeHTTP) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + wantJson := `{"status": "OK"}` + assert.JSONEq(t, wantJson, rr.Body.String()) +} + +func Test_PrometheusClient_MonitorRequestTime(t *testing.T) { + mPrometheusClient := &prometheusClient{} + + metricsRegistry := prometheus.NewRegistry() + metricsRegistry.MustRegister(SummaryVecMetrics[HttpRequestDurationTag]) + + mPrometheusClient.httpHandler = promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) + + mLabels := HttpRequestLabels{ + Status: "200", + Route: "/mock", + Method: "GET", + } + + // initializing durations as 1 second + mDuration := time.Second * 1 + + mPrometheusClient.MonitorHttpRequestDuration(mDuration, mLabels) + + r := chi.NewRouter() + r.Get("/metrics", mPrometheusClient.httpHandler.ServeHTTP) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + sumMetric := `sdp_http_requests_duration_seconds_sum{method="GET",route="/mock",status="200"} 1` + countMetric := `sdp_http_requests_duration_seconds_count{method="GET",route="/mock",status="200"} 1` + + assert.Contains(t, body, sumMetric) + assert.Contains(t, body, countMetric) +} + +func Test_PrometheusClient_MonitorDBQueryDuration(t *testing.T) { + mPrometheusClient := &prometheusClient{} + + metricsRegistry := prometheus.NewRegistry() + metricsRegistry.MustRegister(SummaryVecMetrics[SuccessfulQueryDurationTag]) + metricsRegistry.MustRegister(SummaryVecMetrics[FailureQueryDurationTag]) + + mPrometheusClient.httpHandler = promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) + + mLabels := DBQueryLabels{ + QueryType: "SELECT", + } + + // initializing durations as 1 second + mDuration := time.Second * 1 + + // setup metric handler + r := chi.NewRouter() + r.Get("/metrics", mPrometheusClient.httpHandler.ServeHTTP) + + t.Run("successful db query metric", func(t *testing.T) { + mPrometheusClient.MonitorDBQueryDuration(mDuration, SuccessfulQueryDurationTag, mLabels) + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + sumMetric := `sdp_db_successful_queries_duration_sum{query_type="SELECT"} 1` + countMetric := `sdp_db_successful_queries_duration_count{query_type="SELECT"} 1` + + assert.Contains(t, body, sumMetric) + assert.Contains(t, body, countMetric) + }) + + t.Run("failure db query metric", func(t *testing.T) { + mPrometheusClient.MonitorDBQueryDuration(mDuration, FailureQueryDurationTag, mLabels) + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + sumMetric := `sdp_db_failure_queries_duration_sum{query_type="SELECT"} 1` + countMetric := `sdp_db_failure_queries_duration_count{query_type="SELECT"} 1` + + assert.Contains(t, body, sumMetric) + assert.Contains(t, body, countMetric) + }) +} + +func Test_PrometheusClient_MonitorCounters(t *testing.T) { + mPrometheusClient := &prometheusClient{} + + metricsRegistry := prometheus.NewRegistry() + metricsRegistry.MustRegister(CounterVecMetrics[DisbursementsCounterTag]) + + mPrometheusClient.httpHandler = promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) + + r := chi.NewRouter() + r.Get("/metrics", mPrometheusClient.httpHandler.ServeHTTP) + + t.Run("disbursements counter metric", func(t *testing.T) { + labels := DisbursementLabels{ + Asset: "USDC", + Country: "UKR", + Wallet: "Mock Wallet", + } + + mPrometheusClient.MonitorCounters(DisbursementsCounterTag, labels.ToMap()) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + metric := `sdp_bussiness_disbursements_counter{asset="USDC",country="UKR",wallet="Mock Wallet"} 1` + + assert.Contains(t, body, metric) + + // redefining disbursements counter metrics to have no influence on other tests + CounterVecMetrics[DisbursementsCounterTag].Reset() + }) + + t.Run("counter vec metric not mapped on prometheus metrics", func(t *testing.T) { + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.ErrorLevel) + + labelsMock := map[string]string{ + "mock": "mock_value", + } + + mPrometheusClient.MonitorCounters(MetricTag("counter_vec_mock_tag"), labelsMock) + + require.Contains(t, buf.String(), `level=error msg="metric not registered in prometheus metrics: counter_vec_mock_tag`) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Empty(t, data) + }) + + t.Run("counter metric not mapped on prometheus metrics", func(t *testing.T) { + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.ErrorLevel) + + mPrometheusClient.MonitorCounters(MetricTag("counter_mock_tag"), nil) + + require.Contains(t, buf.String(), `level=error msg="metric not registered in prometheus metrics: counter_mock_tag`) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Empty(t, data) + }) + + // TO-DO add tests for counter metrics when these metrics are added in the app +} diff --git a/internal/monitor/prometheus_metrics.go b/internal/monitor/prometheus_metrics.go new file mode 100644 index 000000000..f4c2f97a3 --- /dev/null +++ b/internal/monitor/prometheus_metrics.go @@ -0,0 +1,37 @@ +package monitor + +import "github.com/prometheus/client_golang/prometheus" + +var SummaryVecMetrics = map[MetricTag]*prometheus.SummaryVec{ + HttpRequestDurationTag: prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Namespace: "sdp", Subsystem: "http", Name: string(HttpRequestDurationTag), + Help: "HTTP requests durations, sliding window = 10m", + }, + []string{"status", "route", "method"}, + ), + SuccessfulQueryDurationTag: prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Namespace: "sdp", Subsystem: "db", Name: string(SuccessfulQueryDurationTag), + Help: "Successful DB query durations", + }, + []string{"query_type"}, + ), + FailureQueryDurationTag: prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Namespace: "sdp", Subsystem: "db", Name: string(FailureQueryDurationTag), + Help: "Failure DB query durations", + }, + []string{"query_type"}, + ), +} + +var CounterMetrics map[MetricTag]prometheus.Counter + +var HistogramVecMetrics map[MetricTag]prometheus.HistogramVec + +var CounterVecMetrics = map[MetricTag]*prometheus.CounterVec{ + DisbursementsCounterTag: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "sdp", Subsystem: "bussiness", Name: string(DisbursementsCounterTag), + Help: "Disbursements Counter", + }, + []string{"asset", "country", "wallet"}, + ), +} diff --git a/internal/monitor/tss_metric_tags.go b/internal/monitor/tss_metric_tags.go new file mode 100644 index 000000000..cf7ecefd7 --- /dev/null +++ b/internal/monitor/tss_metric_tags.go @@ -0,0 +1,33 @@ +package monitor + +const ( + // Metric Tags + HorizonErrorCounterTag MetricTag = "error_count" + TransactionQueuedToCompletedLatencyTag MetricTag = "queued_to_completed_latency_seconds" + TransactionStartedToCompletedLatencyTag MetricTag = "started_to_completed_latency_seconds" + TransactionRetryCountTag MetricTag = "retry_count" + TransactionProcessedCounterTag MetricTag = "processed_count" + + // Metric Labels + TransactionStatusSuccessLabel string = "success" + TransactionStatusErrorLabel string = "error" + + TransactionErrorBuildFeeBumpLabel string = "building_feebump_txn" + TransactionErrorSignFeeBumpLebel string = "sign_feebump_txn" + TransactionErrorBuildPaymentLabel string = "building_payment_txn" + TransactionErrorSignPaymentLebel string = "sign_payment_txn" + TransactionErrorSubmitLabel string = "submitting_payment" + TransactionErrorInvalidStateLabel string = "invalid_state" + TransactionErrorHashingTxnLabel string = "hashing_txn" + TransactionErrorSavingHashLabel string = "saving_hash" +) + +func (m MetricTag) ListAllTSSMetricTags() []MetricTag { + return []MetricTag{ + HorizonErrorCounterTag, + TransactionQueuedToCompletedLatencyTag, + TransactionStartedToCompletedLatencyTag, + TransactionRetryCountTag, + TransactionProcessedCounterTag, + } +} diff --git a/internal/monitor/tss_prometheus_client.go b/internal/monitor/tss_prometheus_client.go new file mode 100644 index 000000000..82c029d22 --- /dev/null +++ b/internal/monitor/tss_prometheus_client.go @@ -0,0 +1,136 @@ +package monitor + +import ( + "fmt" + "net/http" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/sirupsen/logrus" + "github.com/stellar/go/support/log" +) + +type tssPrometheusClient struct { + httpHandler http.Handler +} + +// Metrics is a logrus hook-compliant struct that records metrics about logging +// when added to a logrus.Logger +type Metrics map[logrus.Level]prometheus.Counter + +// Fire is triggered by logrus, in response to a logging event +func (m *Metrics) Fire(e *logrus.Entry) error { + (*m)[e.Level].Inc() + return nil +} + +// Levels returns the logging levels that will trigger this hook to run. In +// this case, all of them. +func (m *Metrics) Levels() []logrus.Level { + return []logrus.Level{ + logrus.WarnLevel, + logrus.ErrorLevel, + logrus.PanicLevel, + } +} + +func (tssPrometheusClient) GetMetricType() MetricType { + return MetricTypeTSSPrometheus +} + +func (p *tssPrometheusClient) GetMetricHttpHandler() http.Handler { + return p.httpHandler +} + +func (p *tssPrometheusClient) MonitorHttpRequestDuration(duration time.Duration, labels HttpRequestLabels) { + SummaryTSSVecMetrics[HttpRequestDurationTag].With(prometheus.Labels{ + "status": labels.Status, + "route": labels.Route, + "method": labels.Method, + }).Observe(duration.Seconds()) +} + +func (p *tssPrometheusClient) MonitorDBQueryDuration(duration time.Duration, tag MetricTag, labels DBQueryLabels) { + summary := SummaryTSSVecMetrics[tag] + summary.With(prometheus.Labels{ + "query_type": labels.QueryType, + }).Observe(duration.Seconds()) +} + +func (p *tssPrometheusClient) MonitorDuration(duration time.Duration, tag MetricTag, labels map[string]string) { + summary := SummaryTSSVecMetrics[tag] + summary.With(labels).Observe(duration.Seconds()) +} + +func (p *tssPrometheusClient) MonitorCounters(tag MetricTag, labels map[string]string) { + if len(labels) != 0 { + if counterVecMetric, ok := CounterTSSVecMetrics[tag]; ok { + counterVecMetric.With(labels).Inc() + } else { + log.Errorf("metric not registered in prometheus metrics: %s", tag) + } + } else { + if counterMetric, ok := CounterTSSMetrics[tag]; ok { + counterMetric.Inc() + } else { + log.Errorf("metric not registered in prometheus metrics: %s", tag) + } + } +} + +func (p *tssPrometheusClient) MonitorHistogram(value float64, tag MetricTag, labels map[string]string) { + histogram := HistogramTSSVecMetrics[tag] + histogram.With(labels).Observe(value) +} + +// NewTSSPrometheusClient registers Prometheus metrics for the Transaction Submission Service +func NewTSSPrometheusClient() (*tssPrometheusClient, error) { + // register Prometheus metrics + metricsRegistry := prometheus.NewRegistry() + + // register default Prometheus metrics + metricsRegistry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + metricsRegistry.MustRegister(collectors.NewGoCollector()) + + var tssMetricTag MetricTag + for _, tag := range tssMetricTag.ListAllTSSMetricTags() { + if summaryTSSVecMetric, ok := SummaryTSSVecMetrics[tag]; ok { + metricsRegistry.MustRegister(summaryTSSVecMetric) + } else if counterTSSMetric, ok := CounterTSSMetrics[tag]; ok { + metricsRegistry.MustRegister(counterTSSMetric) + } else if counterTSSVecMetric, ok := CounterTSSVecMetrics[tag]; ok { + metricsRegistry.MustRegister(counterTSSVecMetric) + } else if histogramTSSVecMetric, ok := HistogramTSSVecMetrics[tag]; ok { + metricsRegistry.MustRegister(histogramTSSVecMetric) + } else { + return nil, fmt.Errorf("metric not registered in prometheus metrics: %s", tag) + } + } + + // create a logging hook that increments a Prometheus counter for each log level + logCounterHook := &Metrics{ + logrus.WarnLevel: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "tss", Subsystem: "log", Name: "warn_total", + }), + logrus.ErrorLevel: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "tss", Subsystem: "log", Name: "error_total", + }), + logrus.PanicLevel: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "tss", Subsystem: "log", Name: "panic_total", + }), + } + + for _, metric := range *logCounterHook { + metricsRegistry.MustRegister(metric) + } + + // add the logCounterHook to the logger + log.DefaultLogger.AddHook(logCounterHook) + + return &tssPrometheusClient{httpHandler: promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{})}, nil +} + +// Ensuring that promtheusClient is implementing MonitorClient interface +var _ MonitorClient = (*tssPrometheusClient)(nil) diff --git a/internal/monitor/tss_prometheus_client_test.go b/internal/monitor/tss_prometheus_client_test.go new file mode 100644 index 000000000..3e1c4d2a2 --- /dev/null +++ b/internal/monitor/tss_prometheus_client_test.go @@ -0,0 +1,277 @@ +package monitor + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_TSSPrometheusClient_GetMetricType(t *testing.T) { + mTSSPrometheusClient := &tssPrometheusClient{} + + metricType := mTSSPrometheusClient.GetMetricType() + assert.Equal(t, MetricTypeTSSPrometheus, metricType) +} + +func Test_TSSPrometheusClient_GetMetricHttpHandler(t *testing.T) { + mTSSPrometheusClient := &tssPrometheusClient{} + + mHttpHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"status": "OK"}`)) + require.NoError(t, err) + }) + + mTSSPrometheusClient.httpHandler = mHttpHandler + + httpHandler := mTSSPrometheusClient.GetMetricHttpHandler() + + r := chi.NewRouter() + r.Get("/metrics", httpHandler.ServeHTTP) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + wantJson := `{"status": "OK"}` + assert.JSONEq(t, wantJson, rr.Body.String()) +} + +func Test_TSSPrometheusClient_MonitorDBQueryDuration(t *testing.T) { + mTSSPrometheusClient := &tssPrometheusClient{} + + metricsRegistry := prometheus.NewRegistry() + metricsRegistry.MustRegister(SummaryTSSVecMetrics[SuccessfulQueryDurationTag]) + metricsRegistry.MustRegister(SummaryTSSVecMetrics[FailureQueryDurationTag]) + + mTSSPrometheusClient.httpHandler = promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) + + mLabels := DBQueryLabels{ + QueryType: "SELECT", + } + + // initializing durations as 1 second + mDuration := time.Second * 1 + + // setup metric handler + r := chi.NewRouter() + r.Get("/metrics", mTSSPrometheusClient.httpHandler.ServeHTTP) + + t.Run("successful db query metric", func(t *testing.T) { + mTSSPrometheusClient.MonitorDBQueryDuration(mDuration, SuccessfulQueryDurationTag, mLabels) + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + sumMetric := `tss_db_successful_queries_duration_sum{query_type="SELECT"} 1` + countMetric := `tss_db_successful_queries_duration_count{query_type="SELECT"} 1` + + assert.Contains(t, body, sumMetric) + assert.Contains(t, body, countMetric) + }) + + t.Run("failure db query metric", func(t *testing.T) { + mTSSPrometheusClient.MonitorDBQueryDuration(mDuration, FailureQueryDurationTag, mLabels) + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + sumMetric := `tss_db_failure_queries_duration_sum{query_type="SELECT"} 1` + countMetric := `tss_db_failure_queries_duration_count{query_type="SELECT"} 1` + + assert.Contains(t, body, sumMetric) + assert.Contains(t, body, countMetric) + }) +} + +func Test_TSSPrometheusClient_MonitorCounters(t *testing.T) { + mTSSPrometheusClient := &tssPrometheusClient{} + + metricsRegistry := prometheus.NewRegistry() + metricsRegistry.MustRegister(CounterTSSVecMetrics[TransactionProcessedCounterTag]) + metricsRegistry.MustRegister(CounterTSSVecMetrics[HorizonErrorCounterTag]) + + mTSSPrometheusClient.httpHandler = promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) + + r := chi.NewRouter() + r.Get("/metrics", mTSSPrometheusClient.httpHandler.ServeHTTP) + + t.Run("transactions processed counter metric", func(t *testing.T) { + labels := map[string]string{ + "result": "success", + "error_type": "none", + "retried": "false", + } + + mTSSPrometheusClient.MonitorCounters(TransactionProcessedCounterTag, labels) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + metric := `tss_tx_processing_processed_count{error_type="none",result="success",retried="false"} 1` + + assert.Contains(t, body, metric) + + CounterTSSVecMetrics[TransactionProcessedCounterTag].Reset() + }) + + t.Run("horizon errors counter metric", func(t *testing.T) { + labels := map[string]string{ + "status_code": "123", + "result_code": "321", + } + + mTSSPrometheusClient.MonitorCounters(HorizonErrorCounterTag, labels) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + metric := `tss_horizon_client_error_count{result_code="321",status_code="123"} 1` + + assert.Contains(t, body, metric) + + CounterTSSVecMetrics[HorizonErrorCounterTag].Reset() + }) +} + +func Test_TSSPrometheusClient_MonitorHistogram(t *testing.T) { + mTSSPrometheusClient := &tssPrometheusClient{} + + metricsRegistry := prometheus.NewRegistry() + metricsRegistry.MustRegister(HistogramTSSVecMetrics[TransactionRetryCountTag]) + metricsRegistry.MustRegister(HistogramTSSVecMetrics[TransactionQueuedToCompletedLatencyTag]) + metricsRegistry.MustRegister(HistogramTSSVecMetrics[TransactionStartedToCompletedLatencyTag]) + + mTSSPrometheusClient.httpHandler = promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) + + r := chi.NewRouter() + r.Get("/metrics", mTSSPrometheusClient.httpHandler.ServeHTTP) + + t.Run("transactions processed retry_count histogram metric", func(t *testing.T) { + labels := map[string]string{ + "result": "success", + "error_type": "none", + "retried": "false", + } + + mTSSPrometheusClient.MonitorHistogram(float64(3), TransactionRetryCountTag, labels) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + metric := `tss_tx_processing_retry_count_bucket{error_type="none",result="success",retried="false",le="3"} 1` + + assert.Contains(t, body, metric) + + HistogramTSSVecMetrics[TransactionRetryCountTag].Reset() + }) + + t.Run("transactions processed queued_to_completed_latency_seconds histogram metric", func(t *testing.T) { + labels := map[string]string{ + "result": "success", + "error_type": "none", + "retried": "false", + } + + mTSSPrometheusClient.MonitorHistogram(float64(15), TransactionQueuedToCompletedLatencyTag, labels) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + metric := `tss_tx_processing_queued_to_completed_latency_seconds_bucket{error_type="none",result="success",retried="false",le="15"} 1` + + assert.Contains(t, body, metric) + + HistogramTSSVecMetrics[TransactionQueuedToCompletedLatencyTag].Reset() + }) + + t.Run("transactions processed started_to_completed_latency_seconds histogram metric", func(t *testing.T) { + labels := map[string]string{ + "result": "success", + "error_type": "none", + "retried": "false", + } + + mTSSPrometheusClient.MonitorHistogram(float64(15), TransactionStartedToCompletedLatencyTag, labels) + + req, err := http.NewRequest("GET", "/metrics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotEmpty(t, data) + body := string(data) + + metric := `tss_tx_processing_started_to_completed_latency_seconds_bucket{error_type="none",result="success",retried="false",le="15"} 1` + + assert.Contains(t, body, metric) + + HistogramTSSVecMetrics[TransactionStartedToCompletedLatencyTag].Reset() + }) +} diff --git a/internal/monitor/tss_prometheus_metrics.go b/internal/monitor/tss_prometheus_metrics.go new file mode 100644 index 000000000..2fd29447e --- /dev/null +++ b/internal/monitor/tss_prometheus_metrics.go @@ -0,0 +1,73 @@ +package monitor + +import "github.com/prometheus/client_golang/prometheus" + +var HistogramTSSVecMetrics = map[MetricTag]*prometheus.HistogramVec{ + TransactionQueuedToCompletedLatencyTag: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "tss", + Subsystem: "tx_processing", + Name: string(TransactionQueuedToCompletedLatencyTag), + Help: "Latency (seconds) taken from when a Transaction was created to when it completed (Success/Error status)", + Buckets: prometheus.LinearBuckets(5, 5, 24), // 5 seconds to 2 minutes + }, + []string{"retried", "result", "error_type"}, + ), + TransactionStartedToCompletedLatencyTag: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "tss", + Subsystem: "tx_processing", + Name: string(TransactionStartedToCompletedLatencyTag), + Help: "Latency (seconds) taken from when a Transaction was started to when it completed (Success/Error status)", + Buckets: prometheus.LinearBuckets(5, 5, 24), + }, + []string{"retried", "result", "error_type"}, + ), + TransactionRetryCountTag: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "tss", + Subsystem: "tx_processing", + Name: string(TransactionRetryCountTag), + Help: "Transaction retry count", + Buckets: prometheus.LinearBuckets(1, 1, 3), // 1 to 3 retries + }, + []string{"retried", "result", "error_type"}, + ), +} + +var SummaryTSSVecMetrics = map[MetricTag]*prometheus.SummaryVec{ + SuccessfulQueryDurationTag: prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Namespace: "tss", + Subsystem: "db", + Name: string(SuccessfulQueryDurationTag), + Help: "Successful DB query durations", + }, + []string{"query_type"}, + ), + FailureQueryDurationTag: prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Namespace: "tss", + Subsystem: "db", + Name: string(FailureQueryDurationTag), + Help: "Failure DB query durations", + }, + []string{"query_type"}, + ), +} + +var CounterTSSMetrics = map[MetricTag]prometheus.Counter{} + +var CounterTSSVecMetrics = map[MetricTag]*prometheus.CounterVec{ + TransactionProcessedCounterTag: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "tss", + Subsystem: "tx_processing", + Name: string(TransactionProcessedCounterTag), + Help: "Count of transactions processed by TSS", + }, + []string{"retried", "result", "error_type"}, + ), + HorizonErrorCounterTag: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "tss", + Subsystem: "horizon_client", + Name: string(HorizonErrorCounterTag), + Help: "Count of Horizon related errors", + }, + []string{"status_code", "result_code"}, + ), +} diff --git a/internal/scheduler/jobs/job.go b/internal/scheduler/jobs/job.go new file mode 100644 index 000000000..bdaf56c55 --- /dev/null +++ b/internal/scheduler/jobs/job.go @@ -0,0 +1,12 @@ +package jobs + +import ( + "context" + "time" +) + +type Job interface { + Execute(context.Context) error + GetInterval() time.Duration + GetName() string +} diff --git a/internal/scheduler/jobs/payments_processor_job.go b/internal/scheduler/jobs/payments_processor_job.go new file mode 100644 index 000000000..361315c5e --- /dev/null +++ b/internal/scheduler/jobs/payments_processor_job.go @@ -0,0 +1,44 @@ +package jobs + +import ( + "context" + "fmt" + "time" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" +) + +type PaymentsProcessorJob struct { + service services.SendPaymentsServiceInterface +} + +const ( + PaymentJobName = "payments_processor_job" + PaymentsJobIntervalSeconds = 60 + PaymentsBatchSize = 100 +) + +func NewPaymentsProcessorJob(models *data.Models) *PaymentsProcessorJob { + return &PaymentsProcessorJob{service: services.NewSendPaymentsService(models)} +} + +func (d PaymentsProcessorJob) GetInterval() time.Duration { + return PaymentsJobIntervalSeconds * time.Second +} + +func (d PaymentsProcessorJob) GetName() string { + return PaymentJobName +} + +func (d PaymentsProcessorJob) Execute(ctx context.Context) error { + log.Ctx(ctx).Infof("executing PaymentsProcessorJob ...") + err := d.service.SendBatchPayments(ctx, PaymentsBatchSize) + if err != nil { + return fmt.Errorf("error executing PaymentsProcessorJob: %w", err) + } + return nil +} + +var _ Job = new(PaymentsProcessorJob) diff --git a/internal/scheduler/jobs/payments_processor_job_test.go b/internal/scheduler/jobs/payments_processor_job_test.go new file mode 100644 index 000000000..57c80d1d8 --- /dev/null +++ b/internal/scheduler/jobs/payments_processor_job_test.go @@ -0,0 +1,74 @@ +package jobs + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// MockSendPaymentsService mocks SendPaymentsService +type MockSendPaymentsService struct { + mock.Mock +} + +func (m *MockSendPaymentsService) SendBatchPayments(ctx context.Context, batchSize int) error { + args := m.Called(ctx, batchSize) + return args.Error(0) +} + +func Test_PaymentsProcessorJob_GetInterval(t *testing.T) { + p := NewPaymentsProcessorJob(&data.Models{}) + require.Equal(t, PaymentsJobIntervalSeconds*time.Second, p.GetInterval()) +} + +func Test_PaymentsProcessorJob_GetName(t *testing.T) { + p := NewPaymentsProcessorJob(&data.Models{}) + require.Equal(t, PaymentJobName, p.GetName()) +} + +func Test_PaymentsProcessorJob_Execute(t *testing.T) { + tests := []struct { + name string + sendPayments func(ctx context.Context, batchSize int) error + wantErr bool + }{ + { + name: "SendBatchPayments success", + sendPayments: func(ctx context.Context, batchSize int) error { + return nil + }, + wantErr: false, + }, + { + name: "SendBatchPayments returns error", + sendPayments: func(ctx context.Context, batchSize int) error { + return fmt.Errorf("error") + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSendPaymentsService := &MockSendPaymentsService{} + mockSendPaymentsService.On("SendBatchPayments", mock.Anything, PaymentsBatchSize). + Return(tt.sendPayments(nil, PaymentsBatchSize)) + + p := PaymentsProcessorJob{ + service: mockSendPaymentsService, + } + + err := p.Execute(context.Background()) + if (err != nil) != tt.wantErr { + t.Errorf("PaymentsProcessorJob.Execute() error = %v, wantErr %v", err, tt.wantErr) + } + + mockSendPaymentsService.AssertExpectations(t) + }) + } +} diff --git a/internal/scheduler/jobs/send_receiver_wallets_sms_invitation_job.go b/internal/scheduler/jobs/send_receiver_wallets_sms_invitation_job.go new file mode 100644 index 000000000..89b41c271 --- /dev/null +++ b/internal/scheduler/jobs/send_receiver_wallets_sms_invitation_job.go @@ -0,0 +1,68 @@ +package jobs + +import ( + "context" + "fmt" + "time" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" +) + +const ( + SendReceiverWalletsSMSInvitationJobName = "send_receiver_wallets_sms_invitation_job" + SendReceiverWalletsSMSInvitationJobIntervalSeconds = 5 +) + +type SendReceiverWalletsSMSInvitationJobOptions struct { + AnchorPlatformBaseSepURL string + Models *data.Models + MessengerClient message.MessengerClient + MinDaysBetweenRetries int + MaxRetries int + Sep10SigningPrivateKey string + CrashTrackerClient crashtracker.CrashTrackerClient +} + +type SendReceiverWalletsSMSInvitationJob struct { + service *services.SendReceiverWalletInviteService +} + +func (j SendReceiverWalletsSMSInvitationJob) GetName() string { + return SendReceiverWalletsSMSInvitationJobName +} + +func (j SendReceiverWalletsSMSInvitationJob) GetInterval() time.Duration { + return time.Second * SendReceiverWalletsSMSInvitationJobIntervalSeconds +} + +func (j SendReceiverWalletsSMSInvitationJob) Execute(ctx context.Context) error { + if err := j.service.SendInvite(ctx); err != nil { + err = fmt.Errorf("error sending invitation SMS to receiver wallets: %w", err) + log.Ctx(ctx).Error(err) + return err + } + return nil +} + +func NewSendReceiverWalletsSMSInvitationJob(options SendReceiverWalletsSMSInvitationJobOptions) *SendReceiverWalletsSMSInvitationJob { + s, err := services.NewSendReceiverWalletInviteService( + options.Models, + options.MessengerClient, + options.AnchorPlatformBaseSepURL, + options.Sep10SigningPrivateKey, + options.MinDaysBetweenRetries, + options.MaxRetries, + options.CrashTrackerClient, + ) + if err != nil { + log.Fatalf("error instantiating service: %s", err.Error()) + } + + return &SendReceiverWalletsSMSInvitationJob{service: s} +} + +var _ Job = new(SendReceiverWalletsSMSInvitationJob) diff --git a/internal/scheduler/jobs/send_receiver_wallets_sms_invitation_job_test.go b/internal/scheduler/jobs/send_receiver_wallets_sms_invitation_job_test.go new file mode 100644 index 000000000..a377ce36d --- /dev/null +++ b/internal/scheduler/jobs/send_receiver_wallets_sms_invitation_job_test.go @@ -0,0 +1,292 @@ +package jobs + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewSendReceiverWalletsSMSInvitationJob(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + anchorPlatformBaseSepURL := "http://localhost:8000" + + messageDryRunClient, err := message.NewDryRunClient() + require.NoError(t, err) + + t.Run("exits with status 1 when Messenger Client is missing config", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + o := SendReceiverWalletsSMSInvitationJobOptions{ + Models: models, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + MinDaysBetweenRetries: 3, + MaxRetries: 3, + } + + NewSendReceiverWalletsSMSInvitationJob(o) + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("exits with status 1 when Base URL is empty", func(t *testing.T) { + if os.Getenv("TEST_FATAL") == "1" { + o := SendReceiverWalletsSMSInvitationJobOptions{ + Models: models, + MessengerClient: messageDryRunClient, + AnchorPlatformBaseSepURL: "", + MinDaysBetweenRetries: 3, + MaxRetries: 3, + } + + NewSendReceiverWalletsSMSInvitationJob(o) + return + } + + // We're using a strategy to setup a cmd inside the test that calls the test itself and verifies if it exited with exit status '1'. + // Ref: https://go.dev/talks/2014/testing.slide#23 + cmd := exec.Command(os.Args[0], fmt.Sprintf("-test.run=%s", t.Name())) + cmd.Env = append(os.Environ(), "TEST_FATAL=1") + + err := cmd.Run() + if exitError, ok := err.(*exec.ExitError); ok { + assert.False(t, exitError.Success()) + return + } + + t.Fatalf("process ran with err %v, want exit status 1", err) + }) + + t.Run("returns a job instance successfully", func(t *testing.T) { + o := SendReceiverWalletsSMSInvitationJobOptions{ + Models: models, + MessengerClient: messageDryRunClient, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + MinDaysBetweenRetries: 3, + MaxRetries: 3, + } + + j := NewSendReceiverWalletsSMSInvitationJob(o) + + assert.NotNil(t, j) + }) +} + +func Test_SendReceiverWalletsSMSInvitationJob(t *testing.T) { + j := SendReceiverWalletsSMSInvitationJob{} + + assert.Equal(t, SendReceiverWalletsSMSInvitationJobName, j.GetName()) + assert.Equal(t, SendReceiverWalletsSMSInvitationJobIntervalSeconds*time.Second, j.GetInterval()) +} + +func Test_SendReceiverWalletsSMSInvitationJob_Execute(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + anchorPlatformBaseSepURL := "http://localhost:8000" + stellarSecretKey := "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5" + + minDaysBetweenRetries := 3 + maxRetries := 3 + + ctx := context.Background() + + t.Run("executes the service successfully", func(t *testing.T) { + messengerClientMock := &message.MessengerClientMock{} + crashTrackerClientMock := &crashtracker.MockCrashTrackerClient{} + + s, err := services.NewSendReceiverWalletInviteService( + models, + messengerClientMock, + anchorPlatformBaseSepURL, + stellarSecretKey, + minDaysBetweenRetries, + maxRetries, + crashTrackerClientMock, + ) + require.NoError(t, err) + + data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + data.ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool) + data.DeleteAllMessagesFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "ATL", "Atlantis") + + wallet1 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet1", "https://wallet1.com", "www.wallet1.com", "wallet1://sdp") + wallet2 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet2", "https://wallet2.com", "www.wallet2.com", "wallet2://sdp") + + asset1 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "FOO1", "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX") + asset2 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "FOO2", "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX") + + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + + disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Country: country, + Wallet: wallet1, + Status: data.ReadyDisbursementStatus, + Asset: asset1, + }) + + disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Country: country, + Wallet: wallet2, + Status: data.ReadyDisbursementStatus, + Asset: asset2, + }) + + rec1RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, data.ReadyReceiversWalletStatus) + data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet2.ID, data.RegisteredReceiversWalletStatus) + + rec2RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet2.ID, data.ReadyReceiversWalletStatus) + + _ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Status: data.ReadyPaymentStatus, + Disbursement: disbursement1, + Asset: *asset1, + ReceiverWallet: rec1RW, + Amount: "1", + }) + + _ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Status: data.ReadyPaymentStatus, + Disbursement: disbursement2, + Asset: *asset2, + ReceiverWallet: rec2RW, + Amount: "1", + }) + + walletDeepLink1 := services.WalletDeepLink{ + DeepLink: wallet1.DeepLinkSchema, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + OrganizationName: "MyCustomAid", + AssetCode: asset1.Code, + AssetIssuer: asset1.Issuer, + } + deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + contentWallet1 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink1) + + walletDeepLink2 := services.WalletDeepLink{ + DeepLink: wallet2.DeepLinkSchema, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + OrganizationName: "MyCustomAid", + AssetCode: asset2.Code, + AssetIssuer: asset2.Issuer, + } + deepLink2, err := walletDeepLink2.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + contentWallet2 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink2) + + mockErr := errors.New("unexpected error") + messengerClientMock. + On("SendMessage", message.Message{ + ToPhoneNumber: receiver1.PhoneNumber, + Message: contentWallet1, + }). + Return(mockErr). + Once(). + On("SendMessage", message.Message{ + ToPhoneNumber: receiver2.PhoneNumber, + Message: contentWallet2, + }). + Return(nil). + Once(). + On("MessengerType"). + Return(message.MessengerTypeTwilioSMS) + + mockMsg := fmt.Sprintf( + "error sending message to receiver ID %s for receiver wallet ID %s using messenger type %s", + receiver1.ID, rec1RW.ID, message.MessengerTypeTwilioSMS, + ) + crashTrackerClientMock.On("LogAndReportErrors", ctx, mockErr, mockMsg).Once() + + err = s.SendInvite(ctx) + require.NoError(t, err) + + q := ` + SELECT + type, status, receiver_id, wallet_id, receiver_wallet_id, + title_encrypted, text_encrypted, status_history + FROM + messages + WHERE + receiver_id = $1 AND wallet_id = $2 AND receiver_wallet_id = $3 + ` + var msg data.Message + err = dbConnectionPool.GetContext(ctx, &msg, q, receiver1.ID, wallet1.ID, rec1RW.ID) + require.NoError(t, err) + + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, receiver1.ID, msg.ReceiverID) + assert.Equal(t, wallet1.ID, msg.WalletID) + assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID) + assert.Equal(t, data.FailureMessageStatus, msg.Status) + assert.Empty(t, msg.TitleEncrypted) + assert.Equal(t, contentWallet1, msg.TextEncrypted) + assert.Len(t, msg.StatusHistory, 2) + assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status) + assert.Equal(t, data.FailureMessageStatus, msg.StatusHistory[1].Status) + assert.Nil(t, msg.AssetID) + + msg = data.Message{} + err = dbConnectionPool.GetContext(ctx, &msg, q, receiver2.ID, wallet2.ID, rec2RW.ID) + require.NoError(t, err) + + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, receiver2.ID, msg.ReceiverID) + assert.Equal(t, wallet2.ID, msg.WalletID) + assert.Equal(t, rec2RW.ID, *msg.ReceiverWalletID) + assert.Equal(t, data.SuccessMessageStatus, msg.Status) + assert.Empty(t, msg.TitleEncrypted) + assert.Equal(t, contentWallet2, msg.TextEncrypted) + assert.Len(t, msg.StatusHistory, 2) + assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status) + assert.Equal(t, data.SuccessMessageStatus, msg.StatusHistory[1].Status) + assert.Nil(t, msg.AssetID) + }) +} diff --git a/internal/scheduler/jobs/tss_monitor_job.go b/internal/scheduler/jobs/tss_monitor_job.go new file mode 100644 index 000000000..6c5851b54 --- /dev/null +++ b/internal/scheduler/jobs/tss_monitor_job.go @@ -0,0 +1,45 @@ +package jobs + +import ( + "context" + "fmt" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +type TSSMonitorJob struct { + service *services.TSSMonitorService +} + +const ( + TSSMonitorJobName = "tss_monitor_job" + TSSMonitorJobIntervalSeconds = 60 + TSSMonitorBatchSize = 100 +) + +func NewTSSMonitorJob(models *data.Models) *TSSMonitorJob { + return &TSSMonitorJob{service: services.NewTSSMonitorService(models)} +} + +func (d TSSMonitorJob) GetInterval() time.Duration { + return TSSMonitorJobIntervalSeconds * time.Second +} + +func (d TSSMonitorJob) GetName() string { + return TSSMonitorJobName +} + +func (d TSSMonitorJob) Execute(ctx context.Context) error { + log.Ctx(ctx).Infof("executing TSSMonitorJob ...") + err := d.service.MonitorTransactions(ctx, TSSMonitorBatchSize) + if err != nil { + return fmt.Errorf("error executing TSSMonitorJob: %w", err) + } + return nil +} + +var _ Job = new(PaymentsProcessorJob) diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go new file mode 100644 index 000000000..82888e011 --- /dev/null +++ b/internal/scheduler/scheduler.go @@ -0,0 +1,169 @@ +package scheduler + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/scheduler/jobs" + + "github.com/stellar/go/support/log" +) + +// Scheduler manages a list of jobs and executes them at their specified intervals. +// It uses a job queue to distribute jobs to workers. +type Scheduler struct { + jobs map[string]jobs.Job + cancel context.CancelFunc + jobQueue chan jobs.Job + crashTrackerClient crashtracker.CrashTrackerClient +} + +type SchedulerOptions struct { + MinDaysBetweenRetries int + MaxRetries int +} + +type SchedulerJobRegisterOption func(*Scheduler) + +// SchedulerWorkerCount is the number of workers that will be started to process jobs +const SchedulerWorkerCount = 5 + +// StartScheduler initializes and starts the scheduler. This method blocks until the scheduler is stopped. +func StartScheduler(crashTrackerClient crashtracker.CrashTrackerClient, schedulerJobRegisters ...SchedulerJobRegisterOption) { + // Call crash tracker FlushEvents to flush buffered events before the scheduler terminates + defer crashTrackerClient.FlushEvents(2 * time.Second) + // Call crash tracker Recover for recover from unhandled panics + defer crashTrackerClient.Recover() + + ctx, cancel := context.WithCancel(context.Background()) + + // create a channel to listen for a shutdown signal + signalChan := make(chan os.Signal, 1) + + // register signal listeners for graceful shutdown + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + + scheduler := newScheduler(cancel) + // add crashTrackerClient to scheduler object + scheduler.crashTrackerClient = crashTrackerClient + + // Registering jobs + for _, schedulerJobRegister := range schedulerJobRegisters { + schedulerJobRegister(scheduler) + } + + scheduler.start(ctx) + + // wait for the shutdown signal here. + <-signalChan + + scheduler.stop() +} + +// newScheduler creates a new scheduler. +func newScheduler(cancel context.CancelFunc) *Scheduler { + return &Scheduler{ + jobs: make(map[string]jobs.Job), + cancel: cancel, + jobQueue: make(chan jobs.Job), + } +} + +// addJob adds a job to the scheduler. This method does not start the job. To start the job, call start(). +func (s *Scheduler) addJob(job jobs.Job) { + s.jobs[job.GetName()] = job +} + +// start starts the scheduler and all jobs. This method blocks until the scheduler is stopped. +func (s *Scheduler) start(ctx context.Context) { + if len(s.jobs) == 0 { + log.Ctx(ctx).Info("No jobs to start") + s.stop() + return + } + log.Ctx(ctx).Infof("Starting scheduler with %d workers...", SchedulerWorkerCount) + + // 1. We start all the workers that will process jobs from the job queue. + for i := 1; i <= SchedulerWorkerCount; i++ { + // start a new worker passing a CrashTrackerClient clone to report errors when the job is executed + go worker(ctx, i, s.crashTrackerClient.Clone(), s.jobQueue) + } + + // 2. Enqueue jobs to jobQueue. + // We start one goroutine per job but these are lightweight because they only wait for the ticker to tick then enqueue the job. + for _, job := range s.jobs { + go func(job jobs.Job) { + ticker := time.NewTicker(job.GetInterval()) + for { + select { + case <-ticker.C: + log.Infof("Enqueuing job: %s", job.GetName()) + s.jobQueue <- job + case <-ctx.Done(): + ticker.Stop() + return + } + } + }(job) + } +} + +// stop uses the context to stop the scheduler and all jobs. +func (s *Scheduler) stop() { + log.Info("Stopping scheduler...") + s.cancel() +} + +// worker is a goroutine that processes jobs from the job queue. +func worker(ctx context.Context, workerID int, crashTrackerClient crashtracker.CrashTrackerClient, jobQueue <-chan jobs.Job) { + defer func() { + if r := recover(); r != nil { + log.Errorf("Worker %d encountered a panic while processing a job: %v", workerID, r) + } + }() + for { + select { + case job := <-jobQueue: + log.Infof("Worker %d processing job: %s", workerID, job.GetName()) + if err := job.Execute(ctx); err != nil { + msg := fmt.Sprintf("error processing job %s on worker %d", job.GetName(), workerID) + // call crash tracker client to log and report error + crashTrackerClient.LogAndReportErrors(ctx, err, msg) + } + case <-ctx.Done(): + log.Infof("Worker %d stopping...", workerID) + return + } + } +} + +func WithPaymentsProcessorJobOption(models *data.Models) SchedulerJobRegisterOption { + return func(s *Scheduler) { + j := jobs.NewPaymentsProcessorJob(models) + log.Infof("registering %s job to scheduler", j.GetName()) + s.addJob(j) + } +} + +func WithTSSMonitorJobOption(models *data.Models) SchedulerJobRegisterOption { + return func(s *Scheduler) { + j := jobs.NewTSSMonitorJob(models) + log.Infof("registering %s job to scheduler", j.GetName()) + s.addJob(j) + } +} + +func WithSendReceiverWalletsSMSInvitationJobOption(o jobs.SendReceiverWalletsSMSInvitationJobOptions) SchedulerJobRegisterOption { + return func(s *Scheduler) { + j := jobs.NewSendReceiverWalletsSMSInvitationJob(o) + log.Infof("registering %s job to scheduler", j.GetName()) + s.addJob(j) + } +} diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go new file mode 100644 index 000000000..2b211ddc8 --- /dev/null +++ b/internal/scheduler/scheduler_test.go @@ -0,0 +1,82 @@ +package scheduler + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stretchr/testify/require" +) + +// MockJob is a mock job created for testing purposes +type MockJob struct { + name string + interval time.Duration + executions int + mu sync.Mutex +} + +func (m *MockJob) GetName() string { + return m.name +} + +func (m *MockJob) GetInterval() time.Duration { + return m.interval +} + +func (m *MockJob) Execute(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + m.executions++ + return nil +} + +func (m *MockJob) GetExecutions() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.executions +} + +func TestScheduler(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + scheduler := newScheduler(cancel) + + mockCrashTrackerClient := &crashtracker.MockCrashTrackerClient{} + scheduler.crashTrackerClient = mockCrashTrackerClient + + clone := crashtracker.MockCrashTrackerClient{} + mockCrashTrackerClient.On("Clone").Return(&clone).Times(5) + + mockJob1 := &MockJob{ + name: "mock_job_1", + interval: 1 * time.Second, + executions: 0, + } + + mockJob2 := &MockJob{ + name: "mock_job_2", + interval: 20 * time.Second, + executions: 0, + } + + scheduler.addJob(mockJob1) + scheduler.addJob(mockJob2) + + // Start the scheduler and wait for a short period to let the job run + scheduler.start(ctx) + time.Sleep(2 * time.Second) + + job1Executions := mockJob1.GetExecutions() + require.True(t, job1Executions > 0, "Expected job to be executed at least once, but it was executed %d times", job1Executions) + + job2Executions := mockJob2.GetExecutions() + require.True(t, job2Executions == 0, "Expected job to be executed 0 times, but it was executed %d times", job2Executions) + + // Test stopping the scheduler + cancel() + time.Sleep(1 * time.Second) + + mockCrashTrackerClient.AssertExpectations(t) +} diff --git a/internal/serve/httpclient/http_client.go b/internal/serve/httpclient/http_client.go new file mode 100644 index 000000000..8ed2acdfd --- /dev/null +++ b/internal/serve/httpclient/http_client.go @@ -0,0 +1,27 @@ +package httpclient + +import ( + "net/http" + "net/url" + "time" + + "github.com/stellar/go/clients/horizonclient" +) + +type HttpClientInterface interface { + Do(*http.Request) (*http.Response, error) + Get(url string) (resp *http.Response, err error) + PostForm(url string, data url.Values) (resp *http.Response, err error) +} + +const TimeoutClientInSeconds = 30 + +// DefaultClient returns a default HTTP client with a timeout. +func DefaultClient() HttpClientInterface { + return &http.Client{Timeout: TimeoutClientInSeconds * time.Second} +} + +var ( + _ HttpClientInterface = DefaultClient() + _ horizonclient.HTTP = DefaultClient() +) diff --git a/internal/serve/httpclient/http_client_mock.go b/internal/serve/httpclient/http_client_mock.go new file mode 100644 index 000000000..e2f979d7e --- /dev/null +++ b/internal/serve/httpclient/http_client_mock.go @@ -0,0 +1,42 @@ +package httpclient + +import ( + "net/http" + "net/url" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stretchr/testify/mock" +) + +type HttpClientMock struct { + mock.Mock +} + +func (h *HttpClientMock) Do(req *http.Request) (*http.Response, error) { + args := h.Called(req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*http.Response), args.Error(1) +} + +func (h *HttpClientMock) Get(url string) (*http.Response, error) { + args := h.Called(url) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*http.Response), args.Error(1) +} + +func (h *HttpClientMock) PostForm(url string, data url.Values) (*http.Response, error) { + args := h.Called(url, data) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*http.Response), args.Error(1) +} + +var ( + _ HttpClientInterface = (*HttpClientMock)(nil) + _ horizonclient.HTTP = (*HttpClientMock)(nil) +) diff --git a/internal/serve/httperror/httperror.go b/internal/serve/httperror/httperror.go new file mode 100644 index 000000000..64594d4cf --- /dev/null +++ b/internal/serve/httperror/httperror.go @@ -0,0 +1,128 @@ +package httperror + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" +) + +type HTTPError struct { + StatusCode int `json:"-"` + Message string `json:"error"` + // Extras contains extra information about the error. + Extras map[string]interface{} `json:"extras,omitempty"` + // Err is an optional field that can be used to wrap the original error to pass it forward. + Err error `json:"-"` +} + +// ReportFunc is a function type used to report unexpected errors. +type ReportErrorFunc func(ctx context.Context, err error, msg string) + +// ReportError is a struct type used to report unexpected errors. +type ReportError struct { + reportErrorFunc ReportErrorFunc +} + +// defaultReportFunc initiliaze defaultReportFunc with a default function. +var defaultReportErrorFunc = ReportError{ + reportErrorFunc: func(ctx context.Context, err error, msg string) { + if msg != "" { + err = fmt.Errorf("%s: %w", msg, err) + } + log.Ctx(ctx).WithStack(err).Errorf("%+v", err) + }, +} + +// SetDefaultReportErrorFunc sets a new defaultReportErrorFunc to report unexpected errors. +func SetDefaultReportErrorFunc(fn ReportErrorFunc) { + defaultReportErrorFunc.reportErrorFunc = fn +} + +func (h *HTTPError) Error() string { + return h.Message +} + +func (e *HTTPError) Unwrap() error { + return e.Err +} + +func (e *HTTPError) Render(w http.ResponseWriter) { + httpjson.RenderStatus(w, e.StatusCode, e, httpjson.JSON) +} + +func NewHTTPError(statusCode int, msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" && originalErr != nil && len(extras) == 0 { + var hErr *HTTPError + if errors.As(originalErr, &hErr) && (hErr.StatusCode == statusCode) { + return hErr + } + } + + return &HTTPError{ + StatusCode: statusCode, + Message: msg, + Extras: extras, + Err: originalErr, + } +} + +func NotFound(msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "Resource not found." + } + return NewHTTPError(http.StatusNotFound, msg, originalErr, extras) +} + +func Conflict(msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "The resource already exists." + } + return NewHTTPError(http.StatusConflict, msg, originalErr, extras) +} + +func BadRequest(msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "The request was invalid in some way." + } + return NewHTTPError(http.StatusBadRequest, msg, originalErr, extras) +} + +func NotImplemented(msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "This feature is not implemented yet." + } + return NewHTTPError(http.StatusNotImplemented, msg, originalErr, extras) +} + +func Unauthorized(msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "Not authorized." + } + return NewHTTPError(http.StatusUnauthorized, msg, originalErr, extras) +} + +func Forbidden(msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "You don't have permission to perform this action." + } + return NewHTTPError(http.StatusForbidden, msg, originalErr, extras) +} + +func InternalError(ctx context.Context, msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "An internal error occurred while processing this request." + } + defaultReportErrorFunc.reportErrorFunc(ctx, originalErr, msg) + return NewHTTPError(http.StatusInternalServerError, msg, originalErr, extras) +} + +func UnprocessableEntity(msg string, originalErr error, extras map[string]interface{}) *HTTPError { + if msg == "" { + msg = "Unprocessable entity." + } + return NewHTTPError(http.StatusUnprocessableEntity, msg, originalErr, extras) +} diff --git a/internal/serve/httperror/httperror_test.go b/internal/serve/httperror/httperror_test.go new file mode 100644 index 000000000..28fd0cb7c --- /dev/null +++ b/internal/serve/httperror/httperror_test.go @@ -0,0 +1,239 @@ +package httperror + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "strings" + "testing" + + "github.com/stellar/go/support/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewHTTPError(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, "Bad request", nil, map[string]interface{}{ + "foo": "bar", + }) + + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + assert.Equal(t, "Bad request", err.Message) + assert.Len(t, err.Extras, 1) + assert.Equal(t, map[string]interface{}{"foo": "bar"}, err.Extras) +} + +func TestNewHTTPError_returnOriginalErrIfNoNewInfoWasAdded(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, "Bad request", nil, map[string]interface{}{ + "foo": "bar", + }) + + // if no new info was added, return original error + newErr := NewHTTPError(http.StatusBadRequest, "", err, nil) + assert.Equal(t, err, newErr) + + // return new error if the message changed + newErr = NewHTTPError(http.StatusBadRequest, "Foo Bar Bad Request", err, nil) + assert.NotEqual(t, err, newErr) + + // return new error if the status code changed + newErr = NewHTTPError(http.StatusNotFound, "", err, nil) + assert.NotEqual(t, err, newErr) + + // return new error if the extras have changed + newErr = NewHTTPError(http.StatusBadRequest, "", err, map[string]interface{}{ + "foo2": "bar2", + }) + assert.NotEqual(t, err, newErr) +} + +func TestNotFound(t *testing.T) { + originalErr := errors.New("original error") + + err := NotFound("", originalErr, map[string]interface{}{"foo": "not found"}) + assert.Equal(t, http.StatusNotFound, err.StatusCode) + assert.Equal(t, "Resource not found.", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Equal(t, map[string]interface{}{"foo": "not found"}, err.Extras) + + err = NotFound("Foo Bar NotFound", nil, nil) + assert.Equal(t, http.StatusNotFound, err.StatusCode) + assert.Equal(t, "Foo Bar NotFound", err.Message) + assert.Nil(t, err.Err) + assert.Nil(t, err.Extras) +} + +func TestBadRequest(t *testing.T) { + originalErr := errors.New("original error") + + err := BadRequest("", originalErr, map[string]interface{}{"foo": "bad request"}) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + assert.Equal(t, "The request was invalid in some way.", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Equal(t, map[string]interface{}{"foo": "bad request"}, err.Extras) + + err = BadRequest("Foo Bar BadRequest", nil, nil) + assert.Equal(t, http.StatusBadRequest, err.StatusCode) + assert.Equal(t, "Foo Bar BadRequest", err.Message) + assert.Nil(t, err.Err) + assert.Nil(t, err.Extras) +} + +func TestUnauthorized(t *testing.T) { + originalErr := errors.New("original error") + + err := Unauthorized("", originalErr, map[string]interface{}{"foo": "invalid token"}) + assert.Equal(t, http.StatusUnauthorized, err.StatusCode) + assert.Equal(t, "Not authorized.", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Equal(t, map[string]interface{}{"foo": "invalid token"}, err.Extras) + + err = Unauthorized("Invalid token provided.", nil, nil) + assert.Equal(t, http.StatusUnauthorized, err.StatusCode) + assert.Equal(t, "Invalid token provided.", err.Message) + assert.Nil(t, err.Err) + assert.Nil(t, err.Extras) +} + +func TestForbidden(t *testing.T) { + originalErr := errors.New("original error") + + err := Forbidden("", originalErr, map[string]interface{}{"foo": "forbidden"}) + assert.Equal(t, http.StatusForbidden, err.StatusCode) + assert.Equal(t, "You don't have permission to perform this action.", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Equal(t, map[string]interface{}{"foo": "forbidden"}, err.Extras) + + err = Forbidden("Foo Bar Forbidden", nil, nil) + assert.Equal(t, http.StatusForbidden, err.StatusCode) + assert.Equal(t, "Foo Bar Forbidden", err.Message) + assert.Nil(t, err.Err) + assert.Nil(t, err.Extras) +} + +func TestInternalError(t *testing.T) { + originalErr := errors.New("original error") + ctx := context.Background() + + t.Run("internal error with default message", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + err := InternalError(ctx, "", originalErr, map[string]interface{}{"foo": "bad server error"}) + assert.Equal(t, http.StatusInternalServerError, err.StatusCode) + assert.Equal(t, "An internal error occurred while processing this request.", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Equal(t, map[string]interface{}{"foo": "bad server error"}, err.Extras) + + // validate logs + require.Contains(t, buf.String(), "An internal error occurred while processing this request.: original error") + }) + + t.Run("internal error with custom message", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + err := InternalError(ctx, "Foo Bar InternalError", originalErr, nil) + assert.Equal(t, http.StatusInternalServerError, err.StatusCode) + assert.Equal(t, "Foo Bar InternalError", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Nil(t, err.Extras) + + // validate logs + require.Contains(t, buf.String(), "Foo Bar InternalError: original error") + }) + + t.Run("internal error without error", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + err := InternalError(ctx, "", nil, nil) + assert.Equal(t, http.StatusInternalServerError, err.StatusCode) + assert.Equal(t, "An internal error occurred while processing this request.", err.Message) + assert.Nil(t, err.Err) + assert.Nil(t, err.Extras) + + // validate logs + require.Contains(t, buf.String(), "An internal error occurred while processing this request.:") + }) + + t.Run("internal error with custom ReportErrorFunc", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + reportErrorFunc := func(ctx context.Context, err error, msg string) { + log.Error("reported with custom ReportFunc") + } + + SetDefaultReportErrorFunc(reportErrorFunc) + + err := InternalError(ctx, "", originalErr, nil) + assert.Equal(t, http.StatusInternalServerError, err.StatusCode) + assert.Equal(t, "An internal error occurred while processing this request.", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Nil(t, err.Extras) + + // validate logs + require.Contains(t, buf.String(), "reported with custom ReportFunc") + }) +} + +func TestUnprocessableEntity(t *testing.T) { + originalErr := errors.New("original error") + + err := UnprocessableEntity("", originalErr, map[string]interface{}{"foo": "invalid token"}) + assert.Equal(t, http.StatusUnprocessableEntity, err.StatusCode) + assert.Equal(t, "Unprocessable entity.", err.Message) + assert.Equal(t, originalErr, err.Err) + assert.Equal(t, map[string]interface{}{"foo": "invalid token"}, err.Extras) + + err = UnprocessableEntity("Could not process your request.", nil, nil) + assert.Equal(t, http.StatusUnprocessableEntity, err.StatusCode) + assert.Equal(t, "Could not process your request.", err.Message) + assert.Nil(t, err.Err) + assert.Nil(t, err.Extras) +} + +func TestNewHTTPError_json(t *testing.T) { + httpErr := NewHTTPError(http.StatusAccepted, "Bad request", nil, map[string]interface{}{ + "foo": "bar", + }) + + gotJson, err := json.Marshal(httpErr) + require.NoError(t, err) + + wantJson := `{ + "error": "Bad request", + "extras": { + "foo": "bar" + } + }` + require.JSONEq(t, wantJson, string(gotJson)) +} + +type testError struct { + Msg string +} + +func (te *testError) Error() string { + return te.Msg +} + +func TestError_unwrap(t *testing.T) { + wrappedError := testError{"wrapped error"} + httpErr := NewHTTPError(http.StatusForbidden, "Bad request", &wrappedError, map[string]interface{}{ + "foo": "bar", + }) + require.Equal(t, &wrappedError, httpErr.Unwrap()) + + require.True(t, errors.Is(httpErr, &wrappedError)) + + var e *testError + require.True(t, errors.As(httpErr, &e)) + require.Equal(t, &wrappedError, e) +} diff --git a/internal/serve/httphandler/assets_handler.go b/internal/serve/httphandler/assets_handler.go new file mode 100644 index 000000000..6e417085c --- /dev/null +++ b/internal/serve/httphandler/assets_handler.go @@ -0,0 +1,284 @@ +package httphandler + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/amount" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/strkey" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" +) + +const feeMultiplierInStroops = 10_000 + +var errCouldNotRemoveTrustline = errors.New("could not remove trustline") + +type AssetsHandler struct { + Models *data.Models + HorizonClient horizonclient.ClientInterface + SignatureService engine.SignatureService +} + +type AssetRequest struct { + Code string `json:"code"` + Issuer string `json:"issuer"` +} + +// GetAssets returns a list of assets. +func (c AssetsHandler) GetAssets(w http.ResponseWriter, r *http.Request) { + assets, err := c.Models.Assets.GetAll(r.Context()) + if err != nil { + ctx := r.Context() + httperror.InternalError(ctx, "Cannot retrieve assets", err, nil).Render(w) + return + } + httpjson.Render(w, assets, httpjson.JSON) +} + +// CreateAsset adds a new asset. +func (c AssetsHandler) CreateAsset(w http.ResponseWriter, r *http.Request) { + var assetRequest AssetRequest + err := json.NewDecoder(r.Body).Decode(&assetRequest) + if err != nil { + httperror.BadRequest("invalid request body", err, nil).Render(w) + return + } + + // TODO: add support for the Stellar Native Asset (XLM) + v := validators.NewValidator() + v.Check(assetRequest.Code != "", "code", "code is required") + v.Check(assetRequest.Issuer != "", "issuer", "issuer is required") + v.Check(strkey.IsValidEd25519PublicKey(assetRequest.Issuer), "issuer", "issuer is invalid") + + if v.HasErrors() { + httperror.BadRequest("Request invalid", err, v.Errors).Render(w) + return + } + + ctx := r.Context() + + asset, err := db.RunInTransactionWithResult(ctx, c.Models.DBConnectionPool, nil, func(dbTx db.DBTransaction) (*data.Asset, error) { + insertedAsset, insertErr := c.Models.Assets.Insert( + ctx, + dbTx, + assetRequest.Code, + assetRequest.Issuer, + ) + if insertErr != nil { + return nil, fmt.Errorf("error inserting new asset: %w", insertErr) + } + + trustlineErr := c.handleUpdateAssetTrustlineForDistributionAccount(ctx, &txnbuild.CreditAsset{ + Code: assetRequest.Code, + Issuer: assetRequest.Issuer, + }, nil) + if trustlineErr != nil { + return nil, fmt.Errorf("error adding trustline for the distribution account: %w", trustlineErr) + } + + return insertedAsset, nil + }) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httperror.Conflict("asset already exists", err, nil).Render(w) + return + } + + httperror.InternalError(ctx, "Cannot create new asset", err, nil).Render(w) + return + } + + httpjson.RenderStatus(w, http.StatusCreated, asset, httpjson.JSON) +} + +// DeleteAsset marks an asset for soft delete. +func (c AssetsHandler) DeleteAsset(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + assetID := chi.URLParam(r, "id") + + asset, err := c.Models.Assets.Get(ctx, assetID) + if err != nil { + log.Ctx(ctx).Errorf("Error performing soft delete on asset id %s: %s", assetID, err.Error()) + httperror.NotFound("could not find asset for deletion", err, nil).Render(w) + return + } + + if asset.DeletedAt != nil { + log.Ctx(ctx).Errorf("Error performing soft delete on asset id %s: %s", assetID, "asset already deleted") + httpjson.RenderStatus(w, http.StatusNoContent, "asset already deleted", httpjson.JSON) + return + } + + asset, err = db.RunInTransactionWithResult(ctx, c.Models.DBConnectionPool, nil, func(dbTx db.DBTransaction) (*data.Asset, error) { + deletedAsset, deleteErr := c.Models.Assets.SoftDelete(ctx, dbTx, assetID) + if deleteErr != nil { + return nil, fmt.Errorf("error performing soft delete on asset id %s: %w", assetID, deleteErr) + } + + trustlineErr := c.handleUpdateAssetTrustlineForDistributionAccount(ctx, nil, &txnbuild.CreditAsset{ + Code: deletedAsset.Code, + Issuer: deletedAsset.Issuer, + }) + if trustlineErr != nil { + return nil, fmt.Errorf("error removing trustline: %w", trustlineErr) + } + + return asset, nil + }) + if err != nil { + if errors.Is(err, errCouldNotRemoveTrustline) { + httperror.UnprocessableEntity("Could not remove trustline because distribution account still has balance", err, nil).Render(w) + return + } + + httperror.InternalError(ctx, "Cannot delete asset", err, nil).Render(w) + return + } + + httpjson.Render(w, asset, httpjson.JSON) +} + +func (c AssetsHandler) handleUpdateAssetTrustlineForDistributionAccount(ctx context.Context, assetToAddTrustline *txnbuild.CreditAsset, assetToRemoveTrustline *txnbuild.CreditAsset) error { + if assetToAddTrustline == nil && assetToRemoveTrustline == nil { + return fmt.Errorf("should provide at least one asset") + } + + if assetToAddTrustline != nil && assetToRemoveTrustline != nil && + *assetToAddTrustline == *assetToRemoveTrustline { + return fmt.Errorf("should provide different assets") + } + + acc, err := c.HorizonClient.AccountDetail(horizonclient.AccountRequest{ + AccountID: c.SignatureService.DistributionAccount(), + }) + if err != nil { + return fmt.Errorf("getting distribution account details: %w", err) + } + + changeTrustOperations := make([]*txnbuild.ChangeTrust, 0) + // remove asset + if assetToRemoveTrustline != nil { + for _, balance := range acc.Balances { + if balance.Asset.Code == assetToRemoveTrustline.Code && balance.Asset.Issuer == assetToRemoveTrustline.Issuer { + assetToRemoveTrustlineBalance, err := amount.ParseInt64(balance.Balance) + if err != nil { + return fmt.Errorf("converting asset to remove trustline balance to int64: %w", err) + } + if assetToRemoveTrustlineBalance > 0 { + log.Ctx(ctx).Warnf( + "not removing trustline for the asset %s:%s because the distribution account still has balance: %s %s", + assetToRemoveTrustline.Code, assetToRemoveTrustline.Issuer, + amount.StringFromInt64(assetToRemoveTrustlineBalance), assetToRemoveTrustline.Code, + ) + return errCouldNotRemoveTrustline + } + + log.Ctx(ctx).Infof("removing trustline for asset %s:%s", assetToRemoveTrustline.Code, assetToRemoveTrustline.Issuer) + changeTrustOperations = append(changeTrustOperations, &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: *assetToRemoveTrustline, + }, + Limit: "0", // 0 means remove trustline + SourceAccount: c.SignatureService.DistributionAccount(), + }) + + break + } + } + + if len(changeTrustOperations) == 0 { + log.Ctx(ctx).Warnf( + "not removing trustline for the asset %s:%s because it could not be found on the blockchain", + assetToRemoveTrustline.Code, assetToRemoveTrustline.Issuer, + ) + } + } + + // add asset + if assetToAddTrustline != nil { + var assetToAddTrustlineFound bool + for _, balance := range acc.Balances { + if balance.Asset.Code == assetToAddTrustline.Code && balance.Asset.Issuer == assetToAddTrustline.Issuer { + assetToAddTrustlineFound = true + log.Ctx(ctx).Warnf("not adding trustline for the asset %s:%s because it already exists", assetToAddTrustline.Code, assetToAddTrustline.Issuer) + break + } + } + + if !assetToAddTrustlineFound { + log.Ctx(ctx).Infof("adding trustline for asset %s:%s", assetToAddTrustline.Code, assetToAddTrustline.Issuer) + changeTrustOperations = append(changeTrustOperations, &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: *assetToAddTrustline, + }, + Limit: "", // empty means no limit + SourceAccount: c.SignatureService.DistributionAccount(), + }) + } + } + + // No operations to perform + if len(changeTrustOperations) == 0 { + log.Ctx(ctx).Warn("not performing either add or remove trustline") + return nil + } + + if err := c.submitChangeTrustTransaction(ctx, &acc, changeTrustOperations); err != nil { + return fmt.Errorf("submitting change trust transaction: %w", err) + } + + return nil +} + +func (c AssetsHandler) submitChangeTrustTransaction(ctx context.Context, acc *horizon.Account, changeTrustOperations []*txnbuild.ChangeTrust) error { + if len(changeTrustOperations) < 1 { + return fmt.Errorf("should have at least one change trust operation") + } + + operations := make([]txnbuild.Operation, 0, len(changeTrustOperations)) + for _, ctOp := range changeTrustOperations { + operations = append(operations, ctOp) + } + + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: c.SignatureService.DistributionAccount(), + Sequence: acc.Sequence, + }, + IncrementSequenceNum: true, + Operations: operations, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + if err != nil { + return fmt.Errorf("creating change trust transaction: %w", err) + } + + tx, err = c.SignatureService.SignStellarTransaction(ctx, tx, c.SignatureService.DistributionAccount()) + if err != nil { + return fmt.Errorf("signing change trust transaction: %w", err) + } + + _, err = c.HorizonClient.SubmitTransactionWithOptions(tx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}) + if err != nil { + return fmt.Errorf("submitting change trust transaction to network: %w", err) + } + + return nil +} diff --git a/internal/serve/httphandler/assets_handler_test.go b/internal/serve/httphandler/assets_handler_test.go new file mode 100644 index 000000000..d76a4a97d --- /dev/null +++ b/internal/serve/httphandler/assets_handler_test.go @@ -0,0 +1,1146 @@ +package httphandler + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/protocols/horizon/base" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/problem" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_AssetsHandlerGetAssets(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + handler := &AssetsHandler{ + Models: models, + } + + t.Run("successfully returns a list of assets", func(t *testing.T) { + expected := data.ClearAndCreateAssetFixtures(t, ctx, dbConnectionPool) + expectedJSON, err := json.Marshal(expected) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/assets", nil) + http.HandlerFunc(handler.GetAssets).ServeHTTP(rr, req) + + resp := rr.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + assert.JSONEq(t, string(expectedJSON), string(respBody)) + }) +} + +func Test_AssetHandlerAddAsset(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + model, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + distributionKP := keypair.MustRandom() + horizonClientMock := &horizonclient.MockClient{} + signatureService := mocks.NewMockSignatureService(t) + + handler := &AssetsHandler{ + Models: model, + SignatureService: signatureService, + HorizonClient: horizonClientMock, + } + + code := "USDT" + issuer := "GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ" + + signatureService. + On("DistributionAccount"). + Return(distributionKP.Address()). + Maybe() + + ctx := context.Background() + + t.Run("successfully create an asset", func(t *testing.T) { + getEntries := log.DefaultLogger.StartTest(log.InfoLevel) + + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", // no limit + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", mock.Anything, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{}, + }, nil). + Once(). + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + + rr := httptest.NewRecorder() + + requestBody, _ := json.Marshal(AssetRequest{code, issuer}) + + req, _ := http.NewRequest(http.MethodPost, "/assets", strings.NewReader(string(requestBody))) + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + entries := getEntries() + assert.Len(t, entries, 1) + assert.Equal(t, "adding trustline for asset USDT:GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ", entries[0].Message) + }) + + t.Run("successfully create an asset with a trustline already set", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + getEntries := log.DefaultLogger.StartTest(log.WarnLevel) + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "100", + Asset: base.Asset{ + Code: code, + Issuer: issuer, + }, + }, + }, + }, nil). + Once() + + rr := httptest.NewRecorder() + + requestBody, _ := json.Marshal(AssetRequest{code, issuer}) + + req, _ := http.NewRequest(http.MethodPost, "/assets", strings.NewReader(string(requestBody))) + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + entries := getEntries() + assert.Len(t, entries, 2) + assert.Equal(t, "not adding trustline for the asset USDT:GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ because it already exists", entries[0].Message) + }) + + t.Run("failed creating asset, issuer invalid", func(t *testing.T) { + rr := httptest.NewRecorder() + + requestBody, _ := json.Marshal(AssetRequest{code, "invalid"}) + + req, _ := http.NewRequest(http.MethodPost, "/assets", strings.NewReader(string(requestBody))) + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) + + t.Run("failed creating asset, missing field", func(t *testing.T) { + rr := httptest.NewRecorder() + + requestBody, _ := json.Marshal(AssetRequest{}) + + req, _ := http.NewRequest(http.MethodPost, "/assets", strings.NewReader(string(requestBody))) + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) + + t.Run("failed creating asset, empty fields", func(t *testing.T) { + rr := httptest.NewRecorder() + + emptyStr := "" + requestBody, _ := json.Marshal(AssetRequest{Code: emptyStr, Issuer: emptyStr}) + + req, _ := http.NewRequest(http.MethodPost, "/assets", strings.NewReader(string(requestBody))) + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) + + t.Run("failed creating asset, duplicated asset", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", // no limit + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", mock.Anything, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{}, + }, nil). + Once(). + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + + // Creating the asset + requestBody, err := json.Marshal(AssetRequest{Code: code, Issuer: issuer}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/assets", bytes.NewReader(requestBody)) + require.NoError(t, err) + + rr := httptest.NewRecorder() + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + // Duplicating the asset + requestBody, err = json.Marshal(AssetRequest{Code: code, Issuer: issuer}) + require.NoError(t, err) + + req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/assets", bytes.NewReader(requestBody)) + require.NoError(t, err) + + rr = httptest.NewRecorder() + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp = rr.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusConflict, resp.StatusCode) + assert.JSONEq(t, `{"error": "asset already exists"}`, string(respBody)) + }) + + t.Run("failed creating asset, error adding asset trustline", func(t *testing.T) { + ctx := context.Background() + + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", // no limit + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", mock.Anything, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{}, + }, nil). + Once(). + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, horizonclient.Error{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + Problem: problem.P{ + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_failed", + "operations": []string{"op_no_issuer"}, + }, + }, + }, + }). + Once() + + // Creating the asset + requestBody, err := json.Marshal(AssetRequest{Code: code, Issuer: issuer}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/assets", bytes.NewReader(requestBody)) + require.NoError(t, err) + + rr := httptest.NewRecorder() + http.HandlerFunc(handler.CreateAsset).ServeHTTP(rr, req) + + resp := rr.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot create new asset"}`, string(respBody)) + }) + + horizonClientMock.AssertExpectations(t) + signatureService.AssertExpectations(t) +} + +func Test_AssetHandlerDeleteAsset(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + model, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + distributionKP := keypair.MustRandom() + horizonClientMock := &horizonclient.MockClient{} + signatureService := mocks.NewMockSignatureService(t) + + handler := &AssetsHandler{ + Models: model, + SignatureService: signatureService, + HorizonClient: horizonClientMock, + } + + r := chi.NewRouter() + r.Delete("/assets/{id}", handler.DeleteAsset) + + signatureService. + On("DistributionAccount"). + Return(distributionKP.Address()). + Maybe() + + t.Run("successfully delete an asset and remove the trustline", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "ABC", "GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ") + + getEntries := log.DefaultLogger.StartTest(log.InfoLevel) + + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: asset.Code, + Issuer: asset.Issuer, + }, + }, + Limit: "0", + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", mock.Anything, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "0", + Asset: base.Asset{ + Code: asset.Code, + Issuer: asset.Issuer, + }, + }, + }, + }, nil). + Once(). + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + + rr := httptest.NewRecorder() + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/assets/%s", asset.ID), nil) + require.NoError(t, err) + r.ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + assetDB, err := model.Assets.Get(ctx, asset.ID) + require.NoError(t, err) + assert.NotNil(t, assetDB.DeletedAt) + + entries := getEntries() + assert.Len(t, entries, 1) + assert.Equal(t, "removing trustline for asset ABC:GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ", entries[0].Message) + }) + + // We decided to not have a mismatch between the Network and the Database. So, if the trustline is not removed, + // the asset won't be deleted as well. + t.Run("doesn't remove the asset when couldn't remove the trustline", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "ABC", "GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ") + + getEntries := log.DefaultLogger.StartTest(log.WarnLevel) + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "100", + Asset: base.Asset{ + Code: asset.Code, + Issuer: asset.Issuer, + }, + }, + }, + }, nil). + Once() + + rr := httptest.NewRecorder() + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/assets/%s", asset.ID), nil) + require.NoError(t, err) + r.ServeHTTP(rr, req) + + resp := rr.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) + assert.JSONEq(t, `{"error": "Could not remove trustline because distribution account still has balance"}`, string(respBody)) + + // Asset should not be soft deleted. + assetDB, err := model.Assets.Get(ctx, asset.ID) + require.NoError(t, err) + assert.Nil(t, assetDB.DeletedAt) + + entries := getEntries() + assert.Len(t, entries, 2) + assert.Equal(t, "not removing trustline for the asset ABC:GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ because the distribution account still has balance: 100.0000000 ABC", entries[0].Message) + }) + + t.Run("returns error when an error occurs removing trustline", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "ABC", "GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ") + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{}, horizonclient.Error{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + Problem: problem.P{ + Title: "Error occurred", + Status: http.StatusBadRequest, + }, + }). + Once() + + rr := httptest.NewRecorder() + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/assets/%s", asset.ID), nil) + require.NoError(t, err) + r.ServeHTTP(rr, req) + + resp := rr.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error":"Cannot delete asset"}`, string(respBody)) + + // Asset should not be soft deleted. + assetDB, err := model.Assets.Get(ctx, asset.ID) + require.NoError(t, err) + assert.Nil(t, assetDB.DeletedAt) + }) + + t.Run("failed deleting an asset, asset not found", func(t *testing.T) { + rr := httptest.NewRecorder() + + req, _ := http.NewRequest(http.MethodDelete, fmt.Sprintf("/assets/%s", "nonexistant"), nil) + r.ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + horizonClientMock.AssertExpectations(t) + signatureService.AssertExpectations(t) +} + +func Test_AssetHandler_handleUpdateAssetTrustlineForDistributionAccount(t *testing.T) { + distributionKP := keypair.MustRandom() + horizonClientMock := &horizonclient.MockClient{} + signatureService := mocks.NewMockSignatureService(t) + + handler := &AssetsHandler{ + SignatureService: signatureService, + HorizonClient: horizonClientMock, + } + + assetToAddTrustline := &txnbuild.CreditAsset{ + Code: "USDC", + Issuer: "GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ", + } + + assetToRemoveTrustline := &txnbuild.CreditAsset{ + Code: "USDT", + Issuer: "GA24LJXFG73JGARIBG2GP6V5TNUUOS6BD23KOFCW3INLDY5KPKS7GACZ", + } + + ctx := context.Background() + + signatureService. + On("DistributionAccount"). + Return(distributionKP.Address()). + Maybe() + + t.Run("returns error if no asset is provided", func(t *testing.T) { + err := handler.handleUpdateAssetTrustlineForDistributionAccount(ctx, nil, nil) + assert.EqualError(t, err, "should provide at least one asset") + }) + + t.Run("returns error if the assets are the same", func(t *testing.T) { + err := handler.handleUpdateAssetTrustlineForDistributionAccount(ctx, assetToRemoveTrustline, assetToRemoveTrustline) + assert.EqualError(t, err, "should provide different assets") + }) + + t.Run("returns error if fails getting distribution account details", func(t *testing.T) { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{}, horizonclient.Error{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + Problem: problem.P{ + Title: "Error occurred", + Status: http.StatusBadRequest, + }, + }). + Once() + + err := handler.handleUpdateAssetTrustlineForDistributionAccount(ctx, assetToAddTrustline, assetToRemoveTrustline) + assert.EqualError(t, err, "getting distribution account details: horizon error: \"Error occurred\" - check horizon.Error.Problem for more information") + }) + + t.Run("returns error if fails submitting change trust transaction", func(t *testing.T) { + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: assetToRemoveTrustline.Code, + Issuer: assetToRemoveTrustline.Issuer, + }, + }, + Limit: "0", + SourceAccount: distributionKP.Address(), + }, + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: assetToAddTrustline.Code, + Issuer: assetToAddTrustline.Issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", ctx, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: "XLM", + Issuer: "", + }, + }, + { + Balance: "0", + Asset: base.Asset{ + Type: "", + Code: assetToRemoveTrustline.Code, + Issuer: assetToRemoveTrustline.Issuer, + }, + }, + }, + }, nil). + Once(). + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, horizonclient.Error{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + Problem: problem.P{ + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_failed", + "operations": []string{"op_no_issuer"}, + }, + }, + }, + }). + Once() + + err = handler.handleUpdateAssetTrustlineForDistributionAccount(ctx, assetToAddTrustline, assetToRemoveTrustline) + assert.EqualError(t, err, "submitting change trust transaction: submitting change trust transaction to network: horizon error: \"\" (tx_failed, op_no_issuer) - check horizon.Error.Problem for more information") + }) + + t.Run("adds and removes the trustlines successfully", func(t *testing.T) { + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: assetToRemoveTrustline.Code, + Issuer: assetToRemoveTrustline.Issuer, + }, + }, + Limit: "0", + SourceAccount: distributionKP.Address(), + }, + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: assetToAddTrustline.Code, + Issuer: assetToAddTrustline.Issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", ctx, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: "XLM", + Issuer: "", + }, + }, + { + Balance: "0", + Asset: base.Asset{ + Type: "", + Code: assetToRemoveTrustline.Code, + Issuer: assetToRemoveTrustline.Issuer, + }, + }, + }, + }, nil). + Once(). + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + + err = handler.handleUpdateAssetTrustlineForDistributionAccount(ctx, assetToAddTrustline, assetToRemoveTrustline) + assert.NoError(t, err) + }) + + t.Run("doesn't remove the trustline in case still has balance", func(t *testing.T) { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: "XLM", + Issuer: "", + }, + }, + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: assetToRemoveTrustline.Code, + Issuer: assetToRemoveTrustline.Issuer, + }, + }, + }, + }, nil). + Once() + + err := handler.handleUpdateAssetTrustlineForDistributionAccount(ctx, assetToAddTrustline, assetToRemoveTrustline) + assert.EqualError(t, err, errCouldNotRemoveTrustline.Error()) + }) + + t.Run("doesn't add new trustline if distribution account already have trustline for the asset", func(t *testing.T) { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{ + AccountID: distributionKP.Address(), + }). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: "XLM", + Issuer: "", + }, + }, + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: assetToAddTrustline.Code, + Issuer: assetToAddTrustline.Issuer, + }, + }, + }, + }, nil). + Once() + + err := handler.handleUpdateAssetTrustlineForDistributionAccount(ctx, assetToAddTrustline, nil) + assert.NoError(t, err) + }) + + horizonClientMock.AssertExpectations(t) + signatureService.AssertExpectations(t) +} + +func Test_AssetHandler_submitChangeTrustTransaction(t *testing.T) { + distributionKP := keypair.MustRandom() + horizonClientMock := &horizonclient.MockClient{} + signatureService := mocks.NewMockSignatureService(t) + + handler := &AssetsHandler{ + SignatureService: signatureService, + HorizonClient: horizonClientMock, + } + + code := "USDC" + issuer := "GBHC5ADV2XYITXCYC5F6X6BM2OYTYHV4ZU2JF6QWJORJQE2O7RKH2LAQ" + + acc := &horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 123, + Balances: []horizon.Balance{ + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: "XLM", + Issuer: "", + }, + }, + { + Balance: "100", + Asset: base.Asset{ + Type: "", + Code: code, + Issuer: issuer, + }, + }, + }, + } + + ctx := context.Background() + + signatureService. + On("DistributionAccount"). + Return(distributionKP.Address()) + + t.Run("returns error if no change trust operations is passed", func(t *testing.T) { + err := handler.submitChangeTrustTransaction(ctx, acc, []*txnbuild.ChangeTrust{}) + assert.EqualError(t, err, "should have at least one change trust operation") + }) + + t.Run("returns error when fails signing transaction", func(t *testing.T) { + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", ctx, tx, distributionKP.Address()). + Return(nil, errors.New("unexpected error")). + Once() + + err = handler.submitChangeTrustTransaction(ctx, acc, []*txnbuild.ChangeTrust{ + { + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }) + assert.EqualError(t, err, "signing change trust transaction: unexpected error") + }) + + t.Run("returns error if fails submitting change trust transaction", func(t *testing.T) { + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", ctx, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, horizonclient.Error{ + Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }, + Problem: problem.P{ + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_failed", + "operations": []string{"op_no_issuer"}, + }, + }, + }, + }). + Once() + + err = handler.submitChangeTrustTransaction(ctx, acc, []*txnbuild.ChangeTrust{ + { + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }) + assert.EqualError(t, err, "submitting change trust transaction to network: horizon error: \"\" (tx_failed, op_no_issuer) - check horizon.Error.Problem for more information") + }) + + t.Run("submits transaction correctly", func(t *testing.T) { + tx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: distributionKP.Address(), + Sequence: 124, + }, + IncrementSequenceNum: false, + Operations: []txnbuild.Operation{ + &txnbuild.ChangeTrust{ + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }, + BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(20)}, + }, + ) + require.NoError(t, err) + + signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) + require.NoError(t, err) + + signatureService. + On("SignStellarTransaction", ctx, tx, distributionKP.Address()). + Return(signedTx, nil). + Once() + + horizonClientMock. + On("SubmitTransactionWithOptions", signedTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + + err = handler.submitChangeTrustTransaction(ctx, acc, []*txnbuild.ChangeTrust{ + { + Line: txnbuild.ChangeTrustAssetWrapper{ + Asset: txnbuild.CreditAsset{ + Code: code, + Issuer: issuer, + }, + }, + Limit: "", + SourceAccount: distributionKP.Address(), + }, + }) + assert.NoError(t, err) + }) + + horizonClientMock.AssertExpectations(t) + signatureService.AssertExpectations(t) +} diff --git a/internal/serve/httphandler/countries_handler.go b/internal/serve/httphandler/countries_handler.go new file mode 100644 index 000000000..413294df2 --- /dev/null +++ b/internal/serve/httphandler/countries_handler.go @@ -0,0 +1,24 @@ +package httphandler + +import ( + "net/http" + + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" +) + +type CountriesHandler struct { + Models *data.Models +} + +// GetCountries returns a list of countries +func (c CountriesHandler) GetCountries(w http.ResponseWriter, r *http.Request) { + countries, err := c.Models.Countries.GetAll(r.Context()) + if err != nil { + ctx := r.Context() + httperror.InternalError(ctx, "Cannot retrieve countries", err, nil).Render(w) + return + } + httpjson.Render(w, countries, httpjson.JSON) +} diff --git a/internal/serve/httphandler/countries_handler_test.go b/internal/serve/httphandler/countries_handler_test.go new file mode 100644 index 000000000..a7cb668fe --- /dev/null +++ b/internal/serve/httphandler/countries_handler_test.go @@ -0,0 +1,53 @@ +package httphandler + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CountriesHandlerGetCountries(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + handler := &CountriesHandler{ + Models: models, + } + + t.Run("successfully returns a list of countries", func(t *testing.T) { + expected := data.ClearAndCreateCountryFixtures(t, ctx, dbConnectionPool) + expectedJSON, err := json.Marshal(expected) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/countries", nil) + http.HandlerFunc(handler.GetCountries).ServeHTTP(rr, req) + + resp := rr.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + assert.JSONEq(t, string(expectedJSON), string(respBody)) + }) +} diff --git a/internal/serve/httphandler/delete_phone_number_handler.go b/internal/serve/httphandler/delete_phone_number_handler.go new file mode 100644 index 000000000..118d21430 --- /dev/null +++ b/internal/serve/httphandler/delete_phone_number_handler.go @@ -0,0 +1,47 @@ +package httphandler + +import ( + "errors" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/network" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type DeletePhoneNumberHandler struct { + NetworkPassphrase string + Models *data.Models +} + +func (d DeletePhoneNumberHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if d.NetworkPassphrase != network.TestNetworkPassphrase { + httperror.NotFound("", nil, nil).Render(w) + return + } + + phoneNumber := chi.URLParam(r, "phone_number") + if err := utils.ValidatePhoneNumber(phoneNumber); err != nil { + extras := map[string]interface{}{"phone_number": "invalid phone number"} + httperror.BadRequest("", nil, extras).Render(w) + return + } + + log.Ctx(ctx).Warnf("Deleting user with phone number %s", utils.TruncateString(phoneNumber, 3)) + err := d.Models.Receiver.DeleteByPhoneNumber(ctx, d.Models.DBConnectionPool, phoneNumber) + if err != nil { + if errors.Is(err, data.ErrRecordNotFound) { + httperror.NotFound("", err, nil).Render(w) + } else { + httperror.InternalError(ctx, "Cannot delete phone number", err, nil).Render(w) + } + return + } + + httpjson.RenderStatus(w, http.StatusNoContent, nil, httpjson.JSON) +} diff --git a/internal/serve/httphandler/delete_phone_number_handler_test.go b/internal/serve/httphandler/delete_phone_number_handler_test.go new file mode 100644 index 000000000..27591ee76 --- /dev/null +++ b/internal/serve/httphandler/delete_phone_number_handler_test.go @@ -0,0 +1,102 @@ +package httphandler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/network" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_DeletePhoneNumberHandler(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{PhoneNumber: "+14152222222"}) + + t.Run("return 404 if network passphrase is not testnet", func(t *testing.T) { + h := DeletePhoneNumberHandler{NetworkPassphrase: network.PublicNetworkPassphrase, Models: models} + r := chi.NewRouter() + r.Delete("/wallet-registration/phone-number/{phone_number}", h.ServeHTTP) + + // test + req, err := http.NewRequest("DELETE", "/wallet-registration/phone-number/"+receiver.PhoneNumber, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusNotFound, rr.Code) + wantJson := `{"error": "Resource not found."}` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("return 400 if network passphrase is testnet but phone number is invalid", func(t *testing.T) { + h := DeletePhoneNumberHandler{NetworkPassphrase: network.TestNetworkPassphrase, Models: models} + r := chi.NewRouter() + r.Delete("/wallet-registration/phone-number/{phone_number}", h.ServeHTTP) + + // test + req, err := http.NewRequest("DELETE", "/wallet-registration/phone-number/foobar", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusBadRequest, rr.Code) + wantJson := `{ + "error": "The request was invalid in some way.", + "extras": { + "phone_number": "invalid phone number" + } + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("return 404 if network passphrase is testnet but phone number does not exist", func(t *testing.T) { + h := DeletePhoneNumberHandler{NetworkPassphrase: network.TestNetworkPassphrase, Models: models} + r := chi.NewRouter() + r.Delete("/wallet-registration/phone-number/{phone_number}", h.ServeHTTP) + + // test + req, err := http.NewRequest("DELETE", "/wallet-registration/phone-number/+14153333333", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusNotFound, rr.Code) + wantJson := `{"error":"Resource not found."}` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("return 204 if network passphrase is testnet and phone nymber exists", func(t *testing.T) { + h := DeletePhoneNumberHandler{NetworkPassphrase: network.TestNetworkPassphrase, Models: models} + r := chi.NewRouter() + r.Delete("/wallet-registration/phone-number/{phone_number}", h.ServeHTTP) + + // test + req, err := http.NewRequest("DELETE", "/wallet-registration/phone-number/"+receiver.PhoneNumber, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusNoContent, rr.Code) + wantJson := "null" + assert.JSONEq(t, wantJson, rr.Body.String()) + }) +} diff --git a/internal/serve/httphandler/disbursement_handler.go b/internal/serve/httphandler/disbursement_handler.go new file mode 100644 index 000000000..0f69f801b --- /dev/null +++ b/internal/serve/httphandler/disbursement_handler.go @@ -0,0 +1,419 @@ +package httphandler + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" + + "github.com/gocarina/gocsv" + + "github.com/go-chi/chi/v5" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpresponse" + + "github.com/stellar/go/support/log" + + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" +) + +type DisbursementHandler struct { + Models *data.Models + MonitorService monitor.MonitorServiceInterface + DBConnectionPool db.DBConnectionPool + AuthManager auth.AuthManager +} + +type PostDisbursementRequest struct { + Name string `json:"name"` + CountryCode string `json:"country_code"` + WalletID string `json:"wallet_id"` + AssetID string `json:"asset_id"` +} + +type PatchDisbursementStatusRequest struct { + Status string `json:"status"` +} + +func (d DisbursementHandler) PostDisbursement(w http.ResponseWriter, r *http.Request) { + var disbursementRequest PostDisbursementRequest + + err := json.NewDecoder(r.Body).Decode(&disbursementRequest) + if err != nil { + httperror.BadRequest("invalid request body", err, nil).Render(w) + return + } + + // validate request + v := validators.NewValidator() + + v.Check(disbursementRequest.Name != "", "name", "name is required") + v.Check(disbursementRequest.CountryCode != "", "country_code", "country_code is required") + v.Check(disbursementRequest.WalletID != "", "wallet_id", "wallet_id is required") + v.Check(disbursementRequest.AssetID != "", "asset_id", "asset_id is required") + + if v.HasErrors() { + httperror.BadRequest("Request invalid", err, v.Errors).Render(w) + return + } + + ctx := r.Context() + wallet, err := d.Models.Wallets.Get(ctx, disbursementRequest.WalletID) + if err != nil { + httperror.BadRequest("wallet ID is invalid", err, nil).Render(w) + return + } + asset, err := d.Models.Assets.Get(ctx, disbursementRequest.AssetID) + if err != nil { + httperror.BadRequest("asset ID is invalid", err, nil).Render(w) + return + } + country, err := d.Models.Countries.Get(ctx, disbursementRequest.CountryCode) + if err != nil { + httperror.BadRequest("country code is invalid", err, nil).Render(w) + return + } + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + msg := fmt.Sprintf("Cannot get token from context when inserting disbursement %s", disbursementRequest.Name) + httperror.InternalError(ctx, msg, nil, nil).Render(w) + return + } + user, err := d.AuthManager.GetUser(ctx, token) + if err != nil { + msg := fmt.Sprintf("Cannot insert disbursement %s", disbursementRequest.Name) + httperror.InternalError(ctx, msg, err, nil).Render(w) + return + } + + disbursement := data.Disbursement{ + Name: disbursementRequest.Name, + Status: data.DraftDisbursementStatus, + StatusHistory: []data.DisbursementStatusHistoryEntry{{ + Timestamp: time.Now(), + Status: data.DraftDisbursementStatus, + UserID: user.ID, + }}, + Wallet: wallet, + Asset: asset, + Country: country, + } + + newId, err := d.Models.Disbursements.Insert(ctx, &disbursement) + if err != nil { + if errors.Is(data.ErrRecordAlreadyExists, err) { + httperror.Conflict("disbursement already exists", err, nil).Render(w) + } else { + httperror.BadRequest("could not create disbursement", err, nil).Render(w) + } + return + } + + newDisbursement, err := d.Models.Disbursements.Get(ctx, d.DBConnectionPool, newId) + if err != nil { + msg := fmt.Sprintf("Cannot retrieve disbursement for ID: %s", newId) + httperror.InternalError(ctx, msg, err, nil).Render(w) + return + } + + labels := monitor.DisbursementLabels{ + Asset: newDisbursement.Asset.Code, + Country: newDisbursement.Country.Code, + Wallet: newDisbursement.Wallet.Name, + } + + err = d.MonitorService.MonitorCounters(monitor.DisbursementsCounterTag, labels.ToMap()) + if err != nil { + log.Ctx(ctx).Errorf("Error trying to monitor disbursement counter: %s", err) + } + + httpjson.RenderStatus(w, http.StatusCreated, newDisbursement, httpjson.JSON) +} + +// GetDisbursements returns a paginated list of disbursements +func (d DisbursementHandler) GetDisbursements(w http.ResponseWriter, r *http.Request) { + validator := validators.NewDisbursementQueryValidator() + queryParams := validator.ParseParametersFromRequest(r) + + if validator.HasErrors() { + httperror.BadRequest("request invalid", nil, validator.Errors).Render(w) + return + } + + queryParams.Filters = validator.ValidateAndGetDisbursementFilters(queryParams.Filters) + if validator.HasErrors() { + httperror.BadRequest("request invalid", nil, validator.Errors).Render(w) + return + } + + ctx := r.Context() + disbursementService := services.NewDisbursementService(d.Models, d.DBConnectionPool, d.AuthManager) + resultWithTotal, err := disbursementService.GetDisbursementsWithCount(ctx, queryParams) + if err != nil { + httperror.InternalError(ctx, "Cannot retrieve disbursements", err, nil).Render(w) + return + } + if resultWithTotal.Total == 0 { + httpjson.RenderStatus(w, http.StatusOK, httpresponse.NewEmptyPaginatedResponse(), httpjson.JSON) + } else { + response, errGet := httpresponse.NewPaginatedResponse(r, resultWithTotal.Result, queryParams.Page, queryParams.PageLimit, resultWithTotal.Total) + if errGet != nil { + httperror.InternalError(ctx, "Cannot write paginated response for disbursements", errGet, nil).Render(w) + return + } + httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON) + } +} + +func (d DisbursementHandler) PostDisbursementInstructions(w http.ResponseWriter, r *http.Request) { + disbursementID := chi.URLParam(r, "id") + + // check if disbursement exists + ctx := r.Context() + disbursement, err := d.Models.Disbursements.Get(ctx, d.DBConnectionPool, disbursementID) + if err != nil { + httperror.BadRequest("disbursement ID is invalid", err, nil).Render(w) + return + } + + // check if disbursement is in draft, ready status + if disbursement.Status != data.DraftDisbursementStatus && disbursement.Status != data.ReadyDisbursementStatus { + httperror.BadRequest("disbursement is not in draft or ready status", nil, nil).Render(w) + return + } + + // Parse uploaded CSV file + file, header, err := r.FormFile("file") + if err != nil { + httperror.BadRequest("could not parse file", err, nil).Render(w) + return + } + defer file.Close() + + // TeeReader is used to read multiple times from the same reader (file) + // We read once to process the instructions, and then again to persist the file to the database + var buf bytes.Buffer + reader := io.TeeReader(file, &buf) + + instructions, v := parseInstructionsFromCSV(reader, disbursement.VerificationField) + if v != nil && v.HasErrors() { + httperror.BadRequest("could not parse csv file", err, v.Errors).Render(w) + return + } + + disbursementUpdate := &data.DisbursementUpdate{ + ID: disbursementID, + FileName: header.Filename, + FileContent: buf.Bytes(), + } + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + msg := fmt.Sprintf("Cannot get token from context when processing instructions for disbursement with ID %s", disbursementID) + httperror.InternalError(ctx, msg, err, nil).Render(w) + return + } + user, err := d.AuthManager.GetUser(ctx, token) + if err != nil { + msg := fmt.Sprintf("Cannot get token from context when processing instructions for disbursement with ID %s", disbursementID) + httperror.InternalError(ctx, msg, err, nil).Render(w) + return + } + + if err = d.Models.DisbursementInstructions.ProcessAll(ctx, user.ID, instructions, disbursement, disbursementUpdate, data.MaxInstructionsPerDisbursement); err != nil { + switch { + case errors.Is(err, data.ErrMaxInstructionsExceeded): + httperror.BadRequest(fmt.Sprintf("number of instructions exceeds maximum of : %d", data.MaxInstructionsPerDisbursement), err, nil).Render(w) + case errors.Is(err, data.ErrReceiverVerificationMismatch): + httperror.BadRequest(errors.Unwrap(err).Error(), err, nil).Render(w) + default: + httperror.InternalError(ctx, fmt.Sprintf("Cannot process instructions for disbursement with ID: %s", disbursementID), err, nil).Render(w) + } + return + } + + response := map[string]string{ + "message": "File uploaded successfully", + } + + httpjson.Render(w, response, httpjson.JSON) +} + +func (d DisbursementHandler) GetDisbursement(w http.ResponseWriter, r *http.Request) { + disbursementID := chi.URLParam(r, "id") + + ctx := r.Context() + disbursement, err := d.Models.Disbursements.GetWithStatistics(ctx, disbursementID) + if err != nil { + if errors.Is(err, data.ErrRecordNotFound) { + httperror.NotFound("disbursement not found", err, nil).Render(w) + } else { + msg := fmt.Sprintf("Cannot get receivers for disbursement with ID: %s", disbursementID) + httperror.InternalError(ctx, msg, err, nil).Render(w) + } + return + } + + httpjson.Render(w, disbursement, httpjson.JSON) +} + +func (d DisbursementHandler) GetDisbursementReceivers(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + disbursementID := chi.URLParam(r, "id") + + validator := validators.NewReceiverQueryValidator() + queryParams := validator.ParseParametersFromRequest(r) + + if validator.HasErrors() { + httperror.BadRequest("request invalid", nil, validator.Errors).Render(w) + return + } + + disbursementService := services.NewDisbursementService(d.Models, d.DBConnectionPool, d.AuthManager) + resultWithTotal, err := disbursementService.GetDisbursementReceiversWithCount(ctx, disbursementID, queryParams) + if err != nil { + if errors.Is(err, services.ErrDisbursementNotFound) { + httperror.NotFound("disbursement not found", err, nil).Render(w) + } else { + msg := fmt.Sprintf("Cannot find disbursement with ID: %s", disbursementID) + httperror.InternalError(ctx, msg, err, nil).Render(w) + } + return + } + + if resultWithTotal.Total == 0 { + httpjson.RenderStatus(w, http.StatusOK, httpresponse.NewEmptyPaginatedResponse(), httpjson.JSON) + return + } + + response, err := httpresponse.NewPaginatedResponse(r, resultWithTotal.Result, queryParams.Page, queryParams.PageLimit, resultWithTotal.Total) + if err != nil { + msg := fmt.Sprintf("Cannot write paginated response for disbursement with ID: %s", disbursementID) + httperror.InternalError(ctx, msg, err, nil).Render(w) + return + } + + httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON) +} + +type UpdateDisbursementStatusResponseBody struct { + Message string `json:"message"` +} + +// PatchDisbursementStatus updates the status of a disbursement +func (d DisbursementHandler) PatchDisbursementStatus(w http.ResponseWriter, r *http.Request) { + var patchRequest PatchDisbursementStatusRequest + err := json.NewDecoder(r.Body).Decode(&patchRequest) + if err != nil { + httperror.BadRequest("invalid request body", err, nil).Render(w) + return + } + + // validate request + toStatus, err := data.ToDisbursementStatus(patchRequest.Status) + if err != nil { + httperror.BadRequest("invalid status", err, nil).Render(w) + return + } + + disbursementService := services.NewDisbursementService(d.Models, d.DBConnectionPool, d.AuthManager) + response := UpdateDisbursementStatusResponseBody{} + + ctx := r.Context() + disbursementID := chi.URLParam(r, "id") + switch toStatus { + case data.StartedDisbursementStatus: + err = disbursementService.StartDisbursement(ctx, disbursementID) + response.Message = "Disbursement started" + case data.PausedDisbursementStatus: + err = disbursementService.PauseDisbursement(ctx, disbursementID) + response.Message = "Disbursement paused" + default: + err = services.ErrDisbursementStatusCantBeChanged + } + + if err != nil { + switch { + case errors.Is(err, services.ErrDisbursementNotFound): + httperror.NotFound(services.ErrDisbursementNotFound.Error(), err, nil).Render(w) + case errors.Is(err, services.ErrDisbursementNotReadyToStart): + httperror.BadRequest(services.ErrDisbursementNotReadyToStart.Error(), err, nil).Render(w) + case errors.Is(err, services.ErrDisbursementNotReadyToPause): + httperror.BadRequest(services.ErrDisbursementNotReadyToPause.Error(), err, nil).Render(w) + case errors.Is(err, services.ErrDisbursementStatusCantBeChanged): + httperror.BadRequest(services.ErrDisbursementStatusCantBeChanged.Error(), err, nil).Render(w) + case errors.Is(err, services.ErrDisbursementStartedByCreator): + httperror.Forbidden("Disbursement can't be started by its creator. Approval by another user is required.", err, nil).Render(w) + default: + msg := fmt.Sprintf("Cannot update disbursement ID %s with status: %s", disbursementID, toStatus) + httperror.InternalError(ctx, msg, err, nil).Render(w) + } + return + } + + httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON) +} + +func (d DisbursementHandler) GetDisbursementInstructions(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + disbursementID := chi.URLParam(r, "id") + + disbursement, err := d.Models.Disbursements.Get(ctx, d.DBConnectionPool, disbursementID) + if err != nil { + httperror.NotFound("disbursement not found", err, nil).Render(w) + return + } + + if len(disbursement.FileContent) == 0 { + err = fmt.Errorf("disbursement %s has no instructions file", disbursementID) + log.Ctx(ctx).Error(err) + httperror.NotFound(err.Error(), err, nil).Render(w) + return + } + + // `attachment` returns a file-download prompt. change that to `inline` to open in browser + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, disbursement.FileName)) + w.Header().Set("Content-Type", "text/csv") + _, err = w.Write(disbursement.FileContent) + if err != nil { + httperror.InternalError(ctx, "Cannot write disbursement instructions to response", err, nil).Render(w) + } +} + +func parseInstructionsFromCSV(file io.Reader, verificationField data.VerificationField) ([]*data.DisbursementInstruction, *validators.DisbursementInstructionsValidator) { + validator := validators.NewDisbursementInstructionsValidator(verificationField) + + instructions := []*data.DisbursementInstruction{} + if err := gocsv.Unmarshal(file, &instructions); err != nil { + log.Errorf("error parsing csv file: %s", err.Error()) + validator.Errors["file"] = "could not parse file" + return nil, validator + } + + for i, instruction := range instructions { + lineNumber := i + 2 // +1 for header row, +1 for 0-index + validator.ValidateInstruction(instruction, lineNumber) + } + + validator.Check(len(instructions) > 0, "instructions", "no valid instructions found") + + if validator.HasErrors() { + return nil, validator + } + + return instructions, nil +} diff --git a/internal/serve/httphandler/disbursement_handler_test.go b/internal/serve/httphandler/disbursement_handler_test.go new file mode 100644 index 000000000..22b87e1d6 --- /dev/null +++ b/internal/serve/httphandler/disbursement_handler_test.go @@ -0,0 +1,1380 @@ +package httphandler + +import ( + "bytes" + "context" + "encoding/csv" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + + "github.com/go-chi/chi/v5" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpresponse" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" + + "github.com/stretchr/testify/require" +) + +func Test_DisbursementHandler_PostDisbursement(t *testing.T) { + const url = "/disbursements" + const method = "POST" + + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + token := "token" + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + user := &auth.User{ + ID: "user-id", + Email: "email@email.com", + } + authManagerMock := &auth.AuthManagerMock{} + authManagerMock. + On("GetUser", mock.Anything, token). + Return(user, nil) + + mMonitorService := &monitor.MockMonitorService{} + + handler := &DisbursementHandler{ + Models: models, + MonitorService: mMonitorService, + DBConnectionPool: models.DBConnectionPool, + AuthManager: authManagerMock, + } + + // setup fixtures + wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool) + asset := data.GetAssetFixture(t, ctx, dbConnectionPool, data.FixtureAssetUSDC) + country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUKR) + + t.Run("returns error when body is invalid", func(t *testing.T) { + requestBody := ` + { + "name": "My New Disbursement name 5", + }` + + want := `{"error":"invalid request body"}` + + assertPOSTResponse(t, ctx, handler, method, url, requestBody, want, http.StatusBadRequest) + }) + + t.Run("returns error when name is not provided", func(t *testing.T) { + requestBody := ` + { + "wallet_id": "aab4a4a9-2493-4f37-9741-01d5bd31d68b", + "asset_id": "61dbfa89-943a-413c-b862-a2177384d321", + "country_code": "UKR" + }` + + want := ` + { + "error":"Request invalid", + "extras": { + "name": "name is required" + } + }` + + assertPOSTResponse(t, ctx, handler, method, url, requestBody, want, http.StatusBadRequest) + }) + + t.Run("returns error when wallet_id is not provided", func(t *testing.T) { + requestBody := ` + { + "name": "My New Disbursement name 5", + "asset_id": "61dbfa89-943a-413c-b862-a2177384d321", + "country_code": "UKR" + }` + + want := `{"error":"Request invalid", "extras": {"wallet_id": "wallet_id is required"}}` + + assertPOSTResponse(t, ctx, handler, method, url, requestBody, want, http.StatusBadRequest) + }) + + t.Run("returns error when asset_id is not provided", func(t *testing.T) { + requestBody := ` + { + "name": "My New Disbursement name 5", + "wallet_id": "aab4a4a9-2493-4f37-9741-01d5bd31d68b", + "country_code": "UKR" + }` + + want := `{"error":"Request invalid", "extras": {"asset_id": "asset_id is required"}}` + + assertPOSTResponse(t, ctx, handler, method, url, requestBody, want, http.StatusBadRequest) + }) + + t.Run("returns error when country_code is not provided", func(t *testing.T) { + requestBody := ` + { + "name": "My New Disbursement name 5", + "wallet_id": "aab4a4a9-2493-4f37-9741-01d5bd31d68b", + "asset_id": "61dbfa89-943a-413c-b862-a2177384d321" + }` + + want := `{"error":"Request invalid", "extras": {"country_code": "country_code is required"}}` + + assertPOSTResponse(t, ctx, handler, method, url, requestBody, want, http.StatusBadRequest) + }) + + t.Run("returns error when wallet_id is not valid", func(t *testing.T) { + requestBody, err := json.Marshal(PostDisbursementRequest{ + Name: "disbursement 1", + CountryCode: country.Code, + AssetID: asset.ID, + WalletID: "aab4a4a9-2493-4f37-9741-01d5bd31d68b", + }) + require.NoError(t, err) + + want := `{"error":"wallet ID is invalid"}` + + assertPOSTResponse(t, ctx, handler, method, url, string(requestBody), want, http.StatusBadRequest) + }) + + t.Run("returns error when asset_id is not valid", func(t *testing.T) { + requestBody, err := json.Marshal(PostDisbursementRequest{ + Name: "disbursement 1", + CountryCode: country.Code, + AssetID: "aab4a4a9-2493-4f37-9741-01d5bd31d68b", + WalletID: wallet.ID, + }) + require.NoError(t, err) + + want := `{"error":"asset ID is invalid"}` + + assertPOSTResponse(t, ctx, handler, method, url, string(requestBody), want, http.StatusBadRequest) + }) + + t.Run("returns error when country_code is not valid", func(t *testing.T) { + requestBody, err := json.Marshal(PostDisbursementRequest{ + Name: "disbursement 1", + CountryCode: "AAA", + AssetID: asset.ID, + WalletID: wallet.ID, + }) + require.NoError(t, err) + + want := `{"error":"country code is invalid"}` + + assertPOSTResponse(t, ctx, handler, method, url, string(requestBody), want, http.StatusBadRequest) + }) + + labels := monitor.DisbursementLabels{ + Asset: asset.Code, + Country: country.Code, + Wallet: wallet.Name, + } + + t.Run("returns error when disbursement name is not unique", func(t *testing.T) { + mMonitorService.On("MonitorCounters", monitor.DisbursementsCounterTag, labels.ToMap()).Return(nil).Once() + + requestBody, err := json.Marshal(PostDisbursementRequest{ + Name: "disbursement 1", + CountryCode: country.Code, + AssetID: asset.ID, + WalletID: wallet.ID, + }) + require.NoError(t, err) + + want := `{"error":"disbursement already exists"}` + + // create disbursement + assertPOSTResponse(t, ctx, handler, method, url, string(requestBody), "", http.StatusCreated) + mMonitorService.AssertExpectations(t) + // try creating again + assertPOSTResponse(t, ctx, handler, method, url, string(requestBody), want, http.StatusConflict) + }) + + t.Run("successfully create a disbursement", func(t *testing.T) { + mMonitorService.On("MonitorCounters", monitor.DisbursementsCounterTag, labels.ToMap()).Return(nil).Once() + + expectedName := "disbursement 2" + requestBody, err := json.Marshal(PostDisbursementRequest{ + Name: expectedName, + CountryCode: country.Code, + AssetID: asset.ID, + WalletID: wallet.ID, + }) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(ctx, method, url, strings.NewReader(string(requestBody))) + http.HandlerFunc(handler.PostDisbursement).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + + var actualDisbursement data.Disbursement + err = json.NewDecoder(resp.Body).Decode(&actualDisbursement) + + require.NoError(t, err) + assert.Equal(t, expectedName, actualDisbursement.Name) + assert.Equal(t, data.DraftDisbursementStatus, actualDisbursement.Status) + assert.Equal(t, asset, actualDisbursement.Asset) + assert.Equal(t, wallet, actualDisbursement.Wallet) + assert.Equal(t, country, actualDisbursement.Country) + assert.Equal(t, 1, len(actualDisbursement.StatusHistory)) + assert.Equal(t, data.DraftDisbursementStatus, actualDisbursement.StatusHistory[0].Status) + assert.Equal(t, user.ID, actualDisbursement.StatusHistory[0].UserID) + mMonitorService.AssertExpectations(t) + }) + + authManagerMock.AssertExpectations(t) +} + +func Test_DisbursementHandler_GetDisbursements_Errors(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &DisbursementHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ts := httptest.NewServer(http.HandlerFunc(handler.GetDisbursements)) + defer ts.Close() + + tests := []struct { + name string + queryParams map[string]string + expectedStatusCode int + expectedResponse string + }{ + { + name: "returns error when sort parameter is invalid", + queryParams: map[string]string{ + "sort": "invalid_sort", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"sort":"invalid sort field name"}}`, + }, + { + name: "returns error when direction is invalid", + queryParams: map[string]string{ + "direction": "invalid_direction", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"direction":"invalid sort order. valid values are 'asc' and 'desc'"}}`, + }, + { + name: "returns error when page is invalid", + queryParams: map[string]string{ + "page": "invalid_page", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"page":"parameter must be an integer"}}`, + }, + { + name: "returns error when page_limit is invalid", + queryParams: map[string]string{ + "page_limit": "invalid_page_limit", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"page_limit":"parameter must be an integer"}}`, + }, + { + name: "returns error when status is invalid", + queryParams: map[string]string{ + "status": "invalid_status", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"status":"invalid parameter. valid value is a comma separate list of statuses: draft, ready, started, paused, completed"}}`, + }, + { + name: "returns error when created_at_after is invalid", + queryParams: map[string]string{ + "created_at_after": "invalid_created_at_after", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"created_at_after":"invalid date format. valid format is 'YYYY-MM-DD'"}}`, + }, + { + name: "returns error when created_at_before is invalid", + queryParams: map[string]string{ + "created_at_before": "invalid_created_at_before", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"created_at_before":"invalid date format. valid format is 'YYYY-MM-DD'"}}`, + }, + { + name: "returns empty list when no expectedDisbursements are found", + queryParams: map[string]string{}, + expectedStatusCode: http.StatusOK, + expectedResponse: `{"data":[], "pagination":{"pages":0, "total":0}}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build the URL for the test request + url := buildURLWithQueryParams(ts.URL, "/disbursements", tc.queryParams) + resp, err := http.Get(url) + require.NoError(t, err) + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tc.expectedStatusCode, resp.StatusCode) + assert.JSONEq(t, tc.expectedResponse, string(respBody)) + }) + } +} + +func Test_DisbursementHandler_GetDisbursements_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &DisbursementHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ts := httptest.NewServer(http.HandlerFunc(handler.GetDisbursements)) + defer ts.Close() + + ctx := context.Background() + + // create fixtures + wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool) + asset := data.GetAssetFixture(t, ctx, dbConnectionPool, data.FixtureAssetUSDC) + country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUKR) + + // create disbursements + disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + CreatedAt: time.Date(2022, 3, 21, 23, 40, 20, 1431, time.UTC), + }) + disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 2", + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + CreatedAt: time.Date(2023, 2, 20, 23, 40, 20, 1431, time.UTC), + }) + disbursement3 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 3", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + CreatedAt: time.Date(2023, 3, 19, 23, 40, 20, 1431, time.UTC), + }) + disbursement4 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 4", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + CreatedAt: time.Date(2023, 4, 19, 23, 40, 20, 1431, time.UTC), + }) + + tests := []struct { + name string + queryParams map[string]string + expectedStatusCode int + expectedPagination httpresponse.PaginationInfo + expectedDisbursements []data.Disbursement + }{ + { + name: "fetch all disbursements without filters", + queryParams: map[string]string{}, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 4, + }, + expectedDisbursements: []data.Disbursement{*disbursement4, *disbursement3, *disbursement2, *disbursement1}, + }, + { + name: "fetch first page of disbursements with limit 1 and sort by name", + queryParams: map[string]string{ + "page": "1", + "page_limit": "1", + "sort": "name", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "/disbursements?direction=asc&page=2&page_limit=1&sort=name", + Prev: "", + Pages: 4, + Total: 4, + }, + expectedDisbursements: []data.Disbursement{*disbursement1}, + }, + { + name: "fetch second page of disbursements with limit 1 and sort by name", + queryParams: map[string]string{ + "page": "2", + "page_limit": "1", + "sort": "name", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "/disbursements?direction=asc&page=3&page_limit=1&sort=name", + Prev: "/disbursements?direction=asc&page=1&page_limit=1&sort=name", + Pages: 4, + Total: 4, + }, + expectedDisbursements: []data.Disbursement{*disbursement2}, + }, + { + name: "fetch last page of disbursements with limit 1 and sort by name", + queryParams: map[string]string{ + "page": "4", + "page_limit": "1", + "sort": "name", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "/disbursements?direction=asc&page=3&page_limit=1&sort=name", + Pages: 4, + Total: 4, + }, + expectedDisbursements: []data.Disbursement{*disbursement4}, + }, + { + name: "fetch last page of disbursements with limit 1 and sort by name", + queryParams: map[string]string{ + "page": "4", + "page_limit": "1", + "sort": "name", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "/disbursements?direction=asc&page=3&page_limit=1&sort=name", + Pages: 4, + Total: 4, + }, + expectedDisbursements: []data.Disbursement{*disbursement4}, + }, + { + name: "fetch disbursements with status draft", + queryParams: map[string]string{ + "status": "dRaFt", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 2, + }, + expectedDisbursements: []data.Disbursement{*disbursement4, *disbursement1}, + }, + { + name: "fetch disbursements with status draft and q=1", + queryParams: map[string]string{ + "status": "draft", + "q": "1", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 1, + }, + expectedDisbursements: []data.Disbursement{*disbursement1}, + }, + { + name: "fetch disbursements after 2023-01-01", + queryParams: map[string]string{ + "created_at_after": "2023-01-01", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 3, + }, + expectedDisbursements: []data.Disbursement{*disbursement4, *disbursement3, *disbursement2}, + }, + { + name: "fetch disbursements after 2023-01-01 and before 2023-03-20", + queryParams: map[string]string{ + "created_at_after": "2023-01-01", + "created_at_before": "2023-03-20", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 2, + }, + expectedDisbursements: []data.Disbursement{*disbursement3, *disbursement2}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build the URL for the test request + url := buildURLWithQueryParams(ts.URL, "/disbursements", tc.queryParams) + resp, err := http.Get(url) + require.NoError(t, err) + defer resp.Body.Close() + + // Parse the response body + var actualResponse httpresponse.PaginatedResponse + err = json.NewDecoder(resp.Body).Decode(&actualResponse) + require.NoError(t, err) + + // Assert on the pagination data + assert.Equal(t, tc.expectedPagination, actualResponse.Pagination) + + // Parse the response data + var actualDisbursements []data.Disbursement + err = json.Unmarshal(actualResponse.Data, &actualDisbursements) + require.NoError(t, err) + + // Assert on the disbursements data + assert.Equal(t, tc.expectedDisbursements, actualDisbursements) + }) + } +} + +func Test_DisbursementHandler_PostDisbursementInstructions(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + mMonitorService := &monitor.MockMonitorService{} + + token := "token" + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + user := &auth.User{ + ID: "user-id", + Email: "email@email.com", + } + authManagerMock := &auth.AuthManagerMock{} + authManagerMock. + On("GetUser", mock.Anything, token). + Return(user, nil). + Run(func(args mock.Arguments) { + mockCtx := args.Get(0).(context.Context) + val := mockCtx.Value(middleware.TokenContextKey) + assert.Equal(t, token, val) + }) + + handler := &DisbursementHandler{ + Models: models, + MonitorService: mMonitorService, + DBConnectionPool: models.DBConnectionPool, + AuthManager: authManagerMock, + } + + // create fixtures + wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool) + asset := data.GetAssetFixture(t, ctx, dbConnectionPool, data.FixtureAssetUSDC) + country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUKR) + + // create disbursement + draftDisbursement := data.CreateDraftDisbursementFixture(t, ctx, dbConnectionPool, handler.Models.Disbursements, data.Disbursement{ + Name: "disbursement1", + Asset: asset, + Country: country, + Wallet: wallet, + }) + + startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.StartedDisbursementStatus, + CreatedAt: time.Date(2022, 3, 21, 23, 40, 20, 1431, time.UTC), + }) + + maxCSVRecords := [][]string{ + {"phone", "id", "amount", "verification"}, + } + for i := 0; i < 10001; i++ { + maxCSVRecords = append(maxCSVRecords, []string{ + "+380445555555", "123456789", "100.5", "1990-01-01", + }) + } + + testCases := []struct { + name string + disbursementID string + fieldName string + csvRecords [][]string + expectedStatus int + expectedMessage string + }{ + { + name: "valid input", + disbursementID: draftDisbursement.ID, + csvRecords: [][]string{ + {"phone", "id", "amount", "verification"}, + {"+380445555555", "123456789", "100.5", "1990-01-01"}, + }, + expectedStatus: http.StatusOK, + expectedMessage: "File uploaded successfully", + }, + { + name: "invalid date of birth", + disbursementID: draftDisbursement.ID, + csvRecords: [][]string{ + {"phone", "id", "amount", "verification"}, + {"+380445555555", "123456789", "100.5", "1990/01/01"}, + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid date of birth format. Correct format: 1990-01-01", + }, + { + name: "invalid phone number", + disbursementID: draftDisbursement.ID, + csvRecords: [][]string{ + {"phone", "id", "amount", "verification"}, + {"380-12-345-678", "123456789", "100.5", "1990-01-01"}, + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "invalid phone format. Correct format: +380445555555", + }, + { + name: "invalid disbursement id", + disbursementID: "invalid", + expectedStatus: http.StatusBadRequest, + expectedMessage: "disbursement ID is invalid", + }, + { + name: "valid input", + disbursementID: draftDisbursement.ID, + fieldName: "instructions", + expectedStatus: http.StatusBadRequest, + expectedMessage: "could not parse file", + }, + { + name: "disbursement not in draft/ready starte", + disbursementID: startedDisbursement.ID, + expectedStatus: http.StatusBadRequest, + expectedMessage: "disbursement is not in draft or ready status", + }, + { + name: "disbursement not in draft/ready state", + disbursementID: startedDisbursement.ID, + expectedStatus: http.StatusBadRequest, + expectedMessage: "disbursement is not in draft or ready status", + }, + { + name: "error parsing header", + disbursementID: draftDisbursement.ID, + csvRecords: [][]string{ + {}, + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "could not parse file", + }, + { + name: "no instructions found in file", + disbursementID: draftDisbursement.ID, + csvRecords: [][]string{ + {"phone", "id", "amount", "date-of-birth"}, + }, + expectedStatus: http.StatusBadRequest, + expectedMessage: "no valid instructions found", + }, + { + name: "max instructions exceeded", + disbursementID: draftDisbursement.ID, + csvRecords: maxCSVRecords, + expectedStatus: http.StatusBadRequest, + expectedMessage: "number of instructions exceeds maximum of : 10000", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fileContent, err := createCSVFile(t, tc.csvRecords) + require.NoError(t, err) + + req, err := createInstructionsMultipartRequest(t, ctx, tc.fieldName, tc.disbursementID, fileContent) + require.NoError(t, err) + + // Record the response + rr := httptest.NewRecorder() + router := chi.NewRouter() + router.Post("/disbursements/{id}/instructions", handler.PostDisbursementInstructions) + router.ServeHTTP(rr, req) + + // Check the response status and message + assert.Equal(t, tc.expectedStatus, rr.Code) + assert.Contains(t, rr.Body.String(), tc.expectedMessage) + }) + + authManagerMock.AssertExpectations(t) + } +} + +func Test_DisbursementHandler_GetDisbursement(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &DisbursementHandler{ + Models: models, + DBConnectionPool: models.DBConnectionPool, + } + + r := chi.NewRouter() + r.Get("/disbursements/{id}", handler.GetDisbursement) + + // create disbursements + disbursement := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.DraftDisbursementStatus, + CreatedAt: time.Date(2022, 3, 21, 23, 40, 20, 1431, time.UTC), + }) + + tests := []struct { + name string + id string + expectedStatusCode int + expectedDisbursement data.Disbursement + expectedErrorMessage string + }{ + { + name: "disbursement not found", + id: "invalid", + expectedStatusCode: http.StatusNotFound, + expectedErrorMessage: "disbursement not found", + }, + { + name: "success", + id: disbursement.ID, + expectedStatusCode: http.StatusOK, + expectedDisbursement: *disbursement, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/disbursements/%s", tc.id), nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code == http.StatusOK { + var actualDisbursement data.Disbursement + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &actualDisbursement)) + require.Equal(t, tc.expectedDisbursement, actualDisbursement) + } else { + var actualErrorMessage httperror.HTTPError + require.Equal(t, tc.expectedStatusCode, rr.Code) + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &actualErrorMessage)) + require.Equal(t, tc.expectedErrorMessage, actualErrorMessage.Message) + } + }) + } +} + +func Test_DisbursementHandler_GetDisbursementReceivers(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &DisbursementHandler{ + Models: models, + DBConnectionPool: models.DBConnectionPool, + } + + r := chi.NewRouter() + r.Get("/disbursements/{id}/receivers", handler.GetDisbursementReceivers) + + // create fixtures + wallet := data.CreateWalletFixture(t, context.Background(), dbConnectionPool, + "My Wallet", + "https://mywallet.com", + "mywallet.com", + "mywallet://") + asset := data.CreateAssetFixture(t, context.Background(), dbConnectionPool, + "USDC", + "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, context.Background(), dbConnectionPool, + "FRA", + "France") + + // create disbursements + disbursementWithReceivers := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "disbursement with receivers", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + disbursementWithoutReceivers := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "disbursement without receivers", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + // create disbursement receivers + ctx := context.Background() + yesterday := time.Now().Add(-time.Hour * 24) + twoDaysAgo := time.Now().Add(-time.Hour * 48) + threeDaysAgo := time.Now().Add(-time.Hour * 72) + + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{CreatedAt: &yesterday}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{CreatedAt: &twoDaysAgo}) + receiver3 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{CreatedAt: &threeDaysAgo}) + + receiverWallet1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.DraftReceiversWalletStatus) + receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.DraftReceiversWalletStatus) + receiverWallet3 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver3.ID, wallet.ID, data.DraftReceiversWalletStatus) + + payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, handler.Models.Payment, &data.Payment{ + ReceiverWallet: receiverWallet1, + Disbursement: disbursementWithReceivers, + Asset: *asset, + Amount: "100", + Status: data.SuccessPaymentStatus, + }) + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, handler.Models.Payment, &data.Payment{ + ReceiverWallet: receiverWallet2, + Disbursement: disbursementWithReceivers, + Asset: *asset, + Amount: "200", + Status: data.SuccessPaymentStatus, + }) + payment3 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, handler.Models.Payment, &data.Payment{ + ReceiverWallet: receiverWallet3, + Disbursement: disbursementWithReceivers, + Asset: *asset, + Amount: "300", + Status: data.SuccessPaymentStatus, + }) + + expectedDisbursementReceivers := []data.DisbursementReceiver{ + { + ID: receiver3.ID, + PhoneNumber: receiver3.PhoneNumber, + Email: *receiver3.Email, + ExternalID: receiver3.ExternalID, + ReceiverWallet: receiverWallet3, + Payment: payment3, + CreatedAt: *receiver3.CreatedAt, + UpdatedAt: *receiver3.UpdatedAt, + }, + { + ID: receiver2.ID, + PhoneNumber: receiver2.PhoneNumber, + Email: *receiver2.Email, + ExternalID: receiver2.ExternalID, + ReceiverWallet: receiverWallet2, + Payment: payment2, + CreatedAt: *receiver2.CreatedAt, + UpdatedAt: *receiver2.UpdatedAt, + }, + { + ID: receiver1.ID, + PhoneNumber: receiver1.PhoneNumber, + Email: *receiver1.Email, + ExternalID: receiver1.ExternalID, + ReceiverWallet: receiverWallet1, + Payment: payment1, + CreatedAt: *receiver1.CreatedAt, + UpdatedAt: *receiver1.UpdatedAt, + }, + } + + t.Run("disbursement doesn't exist", func(t *testing.T) { + id := "5e1f1c7f5b6c9c0001c1b1b1" + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/disbursements/%s/receivers", id), nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusNotFound, rr.Code) + }) + + t.Run("disbursement without receivers", func(t *testing.T) { + id := disbursementWithoutReceivers.ID + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/disbursements/%s/receivers", id), nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + + var actualResponse httpresponse.PaginatedResponse + require.NoError(t, json.NewDecoder(rr.Body).Decode(&actualResponse)) + require.Equal(t, httpresponse.NewEmptyPaginatedResponse(), actualResponse) + }) + + t.Run("disbursement with receivers", func(t *testing.T) { + id := disbursementWithReceivers.ID + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/disbursements/%s/receivers", id), nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + + var actualResponse httpresponse.PaginatedResponse + require.NoError(t, json.NewDecoder(rr.Body).Decode(&actualResponse)) + require.Equal(t, 3, actualResponse.Pagination.Total) + require.Equal(t, 1, actualResponse.Pagination.Pages) + + var actualDisbursementReceivers []data.DisbursementReceiver + require.NoError(t, json.NewDecoder(bytes.NewReader(actualResponse.Data)).Decode(&actualDisbursementReceivers)) + + for i, actual := range actualDisbursementReceivers { + require.Equal(t, expectedDisbursementReceivers[i].ID, actual.ID) + require.Equal(t, expectedDisbursementReceivers[i].PhoneNumber, actual.PhoneNumber) + require.Equal(t, expectedDisbursementReceivers[i].Email, actual.Email) + require.Equal(t, expectedDisbursementReceivers[i].ExternalID, actual.ExternalID) + require.Equal(t, expectedDisbursementReceivers[i].ReceiverWallet.ID, actual.ReceiverWallet.ID) + require.Equal(t, expectedDisbursementReceivers[i].Payment.ID, actual.Payment.ID) + } + }) +} + +func Test_DisbursementHandler_PatchDisbursementStatus(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + token := "token" + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + userID := "valid-user-id" + user := &auth.User{ + ID: userID, + Email: "email@email.com", + } + require.NotNil(t, user) + authManagerMock := &auth.AuthManagerMock{} + + handler := &DisbursementHandler{ + Models: models, + DBConnectionPool: models.DBConnectionPool, + AuthManager: authManagerMock, + } + + r := chi.NewRouter() + r.Patch("/disbursements/{id}/status", handler.PatchDisbursementStatus) + + readyStatusHistory := []data.DisbursementStatusHistoryEntry{ + { + Status: data.DraftDisbursementStatus, + UserID: userID, + }, + { + Status: data.ReadyDisbursementStatus, + UserID: userID, + }, + } + // create disbursements + draftDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "draft disbursement", + Status: data.DraftDisbursementStatus, + }) + + reqBody := bytes.NewBuffer(nil) + t.Run("invalid body", func(t *testing.T) { + id := draftDisbursement.ID + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", id), reqBody) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + require.Contains(t, rr.Body.String(), "invalid request body") + }) + + t.Run("invalid status", func(t *testing.T) { + id := "5e1f1c7f5b6c9c0001c1b1b1" + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "INVALID"}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", id), reqBody) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + require.Contains(t, rr.Body.String(), "invalid status") + }) + + t.Run("disbursement not ready to start", func(t *testing.T) { + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "Started"}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", draftDisbursement.ID), reqBody) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + require.Contains(t, rr.Body.String(), services.ErrDisbursementNotReadyToStart.Error()) + }) + + t.Run("disbursement can't be started by creator", func(t *testing.T) { + data.EnableDisbursementApproval(t, ctx, handler.Models.Organizations) + defer data.DisableDisbursementApproval(t, ctx, handler.Models.Organizations) + + readyDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "ready disbursement #1", + Status: data.ReadyDisbursementStatus, + StatusHistory: readyStatusHistory, + }) + + authManagerMock. + On("GetUser", mock.Anything, token). + Return(user, nil). + Once() + + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "Started"}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", readyDisbursement.ID), reqBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusForbidden, rr.Code) + require.Contains(t, rr.Body.String(), "Disbursement can't be started by its creator. Approval by another user is required") + }) + + t.Run("disbursement can be started by approver who is not a creator", func(t *testing.T) { + data.EnableDisbursementApproval(t, ctx, handler.Models.Organizations) + defer data.DisableDisbursementApproval(t, ctx, handler.Models.Organizations) + + readyDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "ready disbursement #2", + Status: data.ReadyDisbursementStatus, + StatusHistory: readyStatusHistory, + }) + + approverUser := &auth.User{ + ID: "valid-approver-user-id", + Email: "approver@mail.org", + } + + authManagerMock. + On("GetUser", mock.Anything, token). + Return(approverUser, nil). + Once() + + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "Started"}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", readyDisbursement.ID), reqBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Contains(t, rr.Body.String(), "Disbursement started") + }) + + t.Run("disbursement started - then paused", func(t *testing.T) { + authManagerMock. + On("GetUser", mock.Anything, token). + Return(user, nil). + Twice() + readyDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, handler.Models.Disbursements, &data.Disbursement{ + Name: "ready disbursement #3", + Status: data.ReadyDisbursementStatus, + StatusHistory: readyStatusHistory, + }) + + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "Started"}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", readyDisbursement.ID), reqBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Contains(t, rr.Body.String(), "Disbursement started") + + // check disbursement status + disbursement, err := handler.Models.Disbursements.Get(context.Background(), models.DBConnectionPool, readyDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.StartedDisbursementStatus, disbursement.Status) + + // pause disbursement + err = json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "Paused"}) + require.NoError(t, err) + + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", readyDisbursement.ID), reqBody) + require.NoError(t, err) + rr = httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Contains(t, rr.Body.String(), "Disbursement paused") + + // check disbursement status + disbursement, err = handler.Models.Disbursements.Get(context.Background(), models.DBConnectionPool, readyDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.PausedDisbursementStatus, disbursement.Status) + }) + + t.Run("disbursement can't be paused", func(t *testing.T) { + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "Paused"}) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", draftDisbursement.ID), reqBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + require.Contains(t, rr.Body.String(), services.ErrDisbursementNotReadyToPause.Error()) + }) + + t.Run("disbursement status can't be changed", func(t *testing.T) { + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "Completed"}) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", draftDisbursement.ID), reqBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + require.Contains(t, rr.Body.String(), services.ErrDisbursementStatusCantBeChanged.Error()) + }) + + t.Run("disbursement doesn't exist", func(t *testing.T) { + id := "5e1f1c7f5b6c9c0001c1b1b1" + err := json.NewEncoder(reqBody).Encode(PatchDisbursementStatusRequest{Status: "STARTED"}) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("/disbursements/%s/status", id), reqBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusNotFound, rr.Code) + require.Contains(t, rr.Body.String(), services.ErrDisbursementNotFound.Error()) + }) + + authManagerMock.AssertExpectations(t) +} + +func Test_DisbursementHandler_GetDisbursementInstructions(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, outerErr := data.NewModels(dbConnectionPool) + require.NoError(t, outerErr) + + handler := &DisbursementHandler{ + Models: models, + DBConnectionPool: models.DBConnectionPool, + } + + r := chi.NewRouter() + r.Get("/disbursements/{id}/instructions", handler.GetDisbursementInstructions) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{}) + require.NotNil(t, disbursement) + + t.Run("disbursement doesn't exist", func(t *testing.T) { + id := "9e0ff65f-f6e9-46e9-bf03-dc46723e3bfb" + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/disbursements/%s/instructions", id), nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusNotFound, rr.Code) + require.Contains(t, rr.Body.String(), services.ErrDisbursementNotFound.Error()) + }) + + t.Run("disbursement has no instructions", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/disbursements/%s/instructions", disbursement.ID), nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusNotFound, rr.Code) + require.Contains(t, rr.Body.String(), fmt.Sprintf("disbursement %s has no instructions file", disbursement.ID)) + }) + + t.Run("disbursement has instructions", func(t *testing.T) { + disbursementFileContent := data.CreateInstructionsFixture(t, []*data.DisbursementInstruction{ + {Phone: "1234567890", ID: "1", Amount: "123.12", VerificationValue: "1995-02-20"}, + {Phone: "0987654321", ID: "2", Amount: "321", VerificationValue: "1974-07-19"}, + {Phone: "0987654321", ID: "3", Amount: "321", VerificationValue: "1974-07-19"}, + }) + + err := models.Disbursements.Update(ctx, &data.DisbursementUpdate{ + ID: disbursement.ID, + FileContent: disbursementFileContent, + FileName: "instructions.csv", + }) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/disbursements/%s/instructions", disbursement.ID), nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "text/csv", rr.Header().Get("Content-Type")) + require.Equal(t, "attachment; filename=\"instructions.csv\"", rr.Header().Get("Content-Disposition")) + require.Equal(t, string(disbursementFileContent), rr.Body.String()) + }) +} + +func createCSVFile(t *testing.T, records [][]string) (io.Reader, error) { + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + for _, record := range records { + err := writer.Write(record) + require.NoError(t, err) + } + writer.Flush() + return &buf, nil +} + +func createInstructionsMultipartRequest(t *testing.T, ctx context.Context, fieldName, disbursementID string, fileContent io.Reader) (*http.Request, error) { + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + + if fieldName == "" { + fieldName = "file" + } + + part, err := writer.CreateFormFile(fieldName, "instructions.csv") + require.NoError(t, err) + + _, err = io.Copy(part, fileContent) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + url := fmt.Sprintf("/disbursements/%s/instructions", disbursementID) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &buf) + require.NoError(t, err) + req.Header.Set("Content-Type", writer.FormDataContentType()) + return req, nil +} + +func assertPOSTResponse(t *testing.T, ctx context.Context, handler *DisbursementHandler, method, url, requestBody, want string, expectedStatus int) { + rr := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(ctx, method, url, strings.NewReader(requestBody)) + http.HandlerFunc(handler.PostDisbursement).ServeHTTP(rr, req) + + resp := rr.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, expectedStatus, resp.StatusCode) + + if want != "" { + assert.JSONEq(t, want, string(respBody)) + } +} + +func buildURLWithQueryParams(baseURL, endpoint string, queryParams map[string]string) string { + url := baseURL + endpoint + if len(queryParams) > 0 { + url += "?" + for k, v := range queryParams { + url += fmt.Sprintf("%s=%s&", k, v) + } + url = strings.TrimSuffix(url, "&") + } + return url +} diff --git a/internal/serve/httphandler/forgot_password_handler.go b/internal/serve/httphandler/forgot_password_handler.go new file mode 100644 index 000000000..499ded237 --- /dev/null +++ b/internal/serve/httphandler/forgot_password_handler.go @@ -0,0 +1,141 @@ +package httphandler + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +const forgotPasswordMessageTitle = "Reset Account Password" + +// ForgotPasswordHandler searches for the user that is requesting a password reset +// and sends an email with a link to access the password reset page. +type ForgotPasswordHandler struct { + AuthManager auth.AuthManager + MessengerClient message.MessengerClient + UIBaseURL string + Models *data.Models + ReCAPTCHAValidator validators.ReCAPTCHAValidator + ReCAPTCHAEnabled bool +} + +type ForgotPasswordRequest struct { + Email string `json:"email"` + ReCAPTCHAToken string `json:"recaptcha_token"` +} + +type ForgotPasswordResponseBody struct { + Message string `json:"message"` +} + +// ServeHTTP implements the http.Handler interface. +func (h ForgotPasswordHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var forgotPasswordRequest ForgotPasswordRequest + + err := json.NewDecoder(r.Body).Decode(&forgotPasswordRequest) + if err != nil { + httperror.BadRequest("invalid request body", err, nil).Render(w) + return + } + + ctx := r.Context() + + if h.ReCAPTCHAEnabled { + // validating reCAPTCHA Token + isValid, recaptchaErr := h.ReCAPTCHAValidator.IsTokenValid(ctx, forgotPasswordRequest.ReCAPTCHAToken) + if recaptchaErr != nil { + httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", recaptchaErr, nil).Render(w) + return + } + + if !isValid { + log.Ctx(ctx).Errorf("reCAPTCHA token is invalid for request with email %s", utils.TruncateString(forgotPasswordRequest.Email, 3)) + httperror.BadRequest("reCAPTCHA token invalid", nil, nil).Render(w) + return + } + } + + // validate request + v := validators.NewValidator() + + v.Check(forgotPasswordRequest.Email != "", "email", "email is required") + + if v.HasErrors() { + httperror.BadRequest("request invalid", err, v.Errors).Render(w) + return + } + + resetToken, err := h.AuthManager.ForgotPassword(ctx, forgotPasswordRequest.Email) + // if we don't find the user by email, we just return an ok response + // to prevent malicious client from searching accounts in the system + if err != nil { + if errors.Is(err, auth.ErrUserNotFound) { + log.Ctx(ctx).Errorf("error in forgot password handler, email not found: %s", forgotPasswordRequest.Email) + } else if errors.Is(err, auth.ErrUserHasValidToken) { + log.Ctx(ctx).Errorf("error in forgot password handler, user has a valid token") + } else { + httperror.InternalError(ctx, "", err, nil).Render(w) + return + } + } + + if err == nil { + organization, err := h.Models.Organizations.Get(ctx) + if err != nil { + err = fmt.Errorf("error getting organization data: %w", err) + httperror.InternalError(ctx, "", err, nil).Render(w) + return + } + + resetPasswordLink, err := url.JoinPath(h.UIBaseURL, "reset-password") + if err != nil { + err = fmt.Errorf("error getting reset password link: %w", err) + log.Ctx(ctx).Error(err) + httperror.InternalError(ctx, "", err, nil).Render(w) + return + } + + forgotPasswordData := htmltemplate.ForgotPasswordMessageTemplate{ + ResetToken: resetToken, + ResetPasswordLink: resetPasswordLink, + OrganizationName: organization.Name, + } + messageContent, err := htmltemplate.ExecuteHTMLTemplateForForgotPasswordMessage(forgotPasswordData) + if err != nil { + err = fmt.Errorf("error executing forgot password message template: %w", err) + httperror.InternalError(ctx, "", err, nil).Render(w) + return + } + + msg := message.Message{ + ToEmail: forgotPasswordRequest.Email, + Title: forgotPasswordMessageTitle, + Message: messageContent, + } + err = h.MessengerClient.SendMessage(msg) + if err != nil { + err = fmt.Errorf("error sending forgot password email for email %s: %w", forgotPasswordRequest.Email, err) + httperror.InternalError(ctx, "", err, nil).Render(w) + return + } + } + + responseBody := ForgotPasswordResponseBody{ + Message: "Password reset requested. If the email is registered, you'll receive a reset link shortly. Check your inbox and spam folders.", + } + + httpjson.RenderStatus(w, http.StatusOK, responseBody, httpjson.JSON) +} diff --git a/internal/serve/httphandler/forgot_password_handler_test.go b/internal/serve/httphandler/forgot_password_handler_test.go new file mode 100644 index 000000000..92fc8d934 --- /dev/null +++ b/internal/serve/httphandler/forgot_password_handler_test.go @@ -0,0 +1,376 @@ +package httphandler + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + urllib "net/url" + "strings" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stretchr/testify/mock" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ForgotPasswordHandler(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + const url = "/forgot-password" + + authenticatorMock := &auth.AuthenticatorMock{} + reCAPTCHAValidatorMock := &validators.ReCAPTCHAValidatorMock{} + authManager := auth.NewAuthManager( + auth.WithCustomAuthenticatorOption(authenticatorMock), + ) + + uiBaseURL := "https://sdp.com" + messengerClientMock := &message.MessengerClientMock{} + handler := &ForgotPasswordHandler{ + AuthManager: authManager, + MessengerClient: messengerClientMock, + Models: models, + UIBaseURL: uiBaseURL, + ReCAPTCHAValidator: reCAPTCHAValidatorMock, + ReCAPTCHAEnabled: true, + } + + t.Run("Should return http status 200 on a valid request", func(t *testing.T) { + requestBody := ` + { + "email": "valid@email.com" , + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + authenticatorMock. + On("ForgotPassword", req.Context(), "valid@email.com"). + Return("resetToken", nil). + Once() + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(true, nil). + Once() + + resetPasswordLink, err := urllib.JoinPath(uiBaseURL, "reset-password") + require.NoError(t, err) + + content, err := htmltemplate.ExecuteHTMLTemplateForForgotPasswordMessage(htmltemplate.ForgotPasswordMessageTemplate{ + ResetToken: "resetToken", + ResetPasswordLink: resetPasswordLink, + OrganizationName: "MyCustomAid", + }) + require.NoError(t, err) + + msg := message.Message{ + ToEmail: "valid@email.com", + Title: forgotPasswordMessageTitle, + Message: content, + } + messengerClientMock. + On("SendMessage", msg). + Return(nil). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Should return http status 500 when the reset password link is invalid", func(t *testing.T) { + requestBody := ` + { + "email": "valid@email.com" , + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + authenticatorMock. + On("ForgotPassword", req.Context(), "valid@email.com"). + Return("resetToken", nil). + Once() + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(true, nil). + Once() + + http.HandlerFunc(ForgotPasswordHandler{ + AuthManager: authManager, + MessengerClient: messengerClientMock, + Models: models, + UIBaseURL: "%invalid%", + ReCAPTCHAValidator: reCAPTCHAValidatorMock, + ReCAPTCHAEnabled: true, + }.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + }) + + t.Run("Should return an http status ok even if the email is not found", func(t *testing.T) { + requestBody := ` + { + "email": "not_found@email.com" , + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + authenticatorMock. + On("ForgotPassword", req.Context(), "not_found@email.com"). + Return("", auth.ErrUserNotFound). + Once() + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(true, nil). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Should return an http status ok even if the user has a valid token", func(t *testing.T) { + requestBody := ` + { + "email": "valid@email.com" , + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + authenticatorMock. + On("ForgotPassword", req.Context(), "valid@email.com"). + Return("", auth.ErrUserHasValidToken). + Once() + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(true, nil). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Should require email param", func(t *testing.T) { + requestBody := ` + { + "email": "", + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(true, nil). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedBody := ` + { + "error":"request invalid", + "extras": { + "email":"email is required" + } + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, expectedBody, string(respBody)) + }) + + t.Run("Should return http status 500 when error sending email", func(t *testing.T) { + requestBody := ` + { + "email": "valid@email.com", + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + authenticatorMock. + On("ForgotPassword", req.Context(), "valid@email.com"). + Return("resetToken", nil). + Once() + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(true, nil). + Once() + + resetPasswordLink, err := urllib.JoinPath(uiBaseURL, "reset-password") + require.NoError(t, err) + + content, err := htmltemplate.ExecuteHTMLTemplateForForgotPasswordMessage(htmltemplate.ForgotPasswordMessageTemplate{ + ResetToken: "resetToken", + ResetPasswordLink: resetPasswordLink, + OrganizationName: "MyCustomAid", + }) + require.NoError(t, err) + + msg := message.Message{ + ToEmail: "valid@email.com", + Title: forgotPasswordMessageTitle, + Message: content, + } + messengerClientMock. + On("SendMessage", msg). + Return(errors.New("unexpected error")). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedBody := ` + { + "error": "An internal error occurred while processing this request." + } + ` + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, expectedBody, string(respBody)) + }) + + t.Run("Should return http status 500 when authenticator fails", func(t *testing.T) { + requestBody := ` + { + "email": "valid@email.com", + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + authenticatorMock. + On("ForgotPassword", req.Context(), "valid@email.com"). + Return("", errors.New("unexpected error")). + Once() + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(true, nil). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedBody := ` + { + "error": "An internal error occurred while processing this request." + } + ` + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, expectedBody, string(respBody)) + }) + + t.Run("Should return http status 500 when reCAPTCHA validator returns an error", func(t *testing.T) { + requestBody := ` + { + "email": "valid@email.com" , + "recaptcha_token": "validToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "validToken"). + Return(false, errors.New("error requesting verify reCAPTCHA token")). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + wantsBody := ` + { + "error": "Cannot validate reCAPTCHA token" + } + ` + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("Should return http status 400 when reCAPTCHA token is invalid", func(t *testing.T) { + requestBody := ` + { + "email": "valid@email.com" , + "recaptcha_token": "invalidToken" + }` + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(requestBody)) + require.NoError(t, err) + + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "invalidToken"). + Return(false, nil). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + wantsBody := ` + { + "error": "reCAPTCHA token invalid" + } + ` + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + authenticatorMock.AssertExpectations(t) + messengerClientMock.AssertExpectations(t) + reCAPTCHAValidatorMock.AssertExpectations(t) +} diff --git a/internal/serve/httphandler/health_handler.go b/internal/serve/httphandler/health_handler.go new file mode 100644 index 000000000..da73ed2f6 --- /dev/null +++ b/internal/serve/httphandler/health_handler.go @@ -0,0 +1,49 @@ +package httphandler + +import ( + "net/http" + + "github.com/stellar/go/support/render/httpjson" +) + +// Status indicates whether the service is health or not. +type Status string + +const ( + // StatusPass indicates that the service is healthy. + StatusPass Status = "pass" + // StatusFail indicates that the service is unhealthy. + StatusFail Status = "fail" +) + +// HealthResponse follows the health check response format for HTTP APIs, +// based on the format defined in the draft IETF network working group +// standard, Health Check Response Format for HTTP APIs. +// +// https://datatracker.ietf.org/doc/html/draft-inadarei-api-health-check-06#name-api-health-response +type HealthResponse struct { + Status Status `json:"status"` + Version string `json:"version,omitempty"` + ServiceID string `json:"service_id,omitempty"` + ReleaseID string `json:"release_id,omitempty"` +} + +// HealthHandler implements a simple handler that returns the health response. +type HealthHandler struct { + Version string + ServiceID string + ReleaseID string +} + +// ServeHTTP implements the http.Handler interface. +func (h HealthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + response := HealthResponse{ + Status: StatusPass, + Version: h.Version, + ServiceID: h.ServiceID, + ReleaseID: h.ReleaseID, + } + + // TODO: after we have a DB connection, we should check if the DB is healthy + httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON) +} diff --git a/internal/serve/httphandler/health_handler_test.go b/internal/serve/httphandler/health_handler_test.go new file mode 100644 index 000000000..4dd39e800 --- /dev/null +++ b/internal/serve/httphandler/health_handler_test.go @@ -0,0 +1,38 @@ +package httphandler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// test HealthHandler: +func TestHealthHandler(t *testing.T) { + // setup + r := chi.NewRouter() + r.Get("/health", HealthHandler{ + Version: "x.y.z", + ServiceID: "my-api", + ReleaseID: "1234567890abcdef", + }.ServeHTTP) + + // test + req, err := http.NewRequest("GET", "/health", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + wantJson := `{ + "status": "pass", + "version": "x.y.z", + "service_id": "my-api", + "release_id": "1234567890abcdef" + }` + assert.JSONEq(t, wantJson, rr.Body.String()) +} diff --git a/internal/serve/httphandler/list_roles_handler.go b/internal/serve/httphandler/list_roles_handler.go new file mode 100644 index 000000000..6356b143a --- /dev/null +++ b/internal/serve/httphandler/list_roles_handler.go @@ -0,0 +1,16 @@ +package httphandler + +import ( + "net/http" + + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +type ListRolesHandler struct{} + +// GetRoles retrieves all the users roles available +func (h ListRolesHandler) GetRoles(rw http.ResponseWriter, req *http.Request) { + roles := map[string][]data.UserRole{"roles": data.GetAllRoles()} + httpjson.Render(rw, roles, httpjson.JSON) +} diff --git a/internal/serve/httphandler/list_roles_handler_test.go b/internal/serve/httphandler/list_roles_handler_test.go new file mode 100644 index 000000000..f47055fab --- /dev/null +++ b/internal/serve/httphandler/list_roles_handler_test.go @@ -0,0 +1,32 @@ +package httphandler + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListRoles(t *testing.T) { + r := chi.NewRouter() + + r.Get("/users/roles", ListRolesHandler{}.GetRoles) + + req, err := http.NewRequest(http.MethodGet, "/users/roles", nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.JSONEq(t, `{"roles": ["owner", "financial_controller", "developer", "business"]}`, string(respBody)) +} diff --git a/internal/serve/httphandler/login_handler.go b/internal/serve/httphandler/login_handler.go new file mode 100644 index 000000000..ef524260d --- /dev/null +++ b/internal/serve/httphandler/login_handler.go @@ -0,0 +1,177 @@ +package httphandler + +import ( + "errors" + "fmt" + "net/http" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + + "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +const mfaMessageTitle = "Verification code to access your account" + +type LoginRequest struct { + Email string `json:"email"` + Password string `json:"password"` + ReCAPTCHAToken string `json:"recaptcha_token"` +} + +func (r LoginRequest) validate() *httperror.HTTPError { + validator := validators.NewValidator() + + validator.Check(r.Email != "", "email", "email is required") + validator.Check(r.Password != "", "password", "password is required") + + if validator.HasErrors() { + return httperror.BadRequest("Request invalid", nil, validator.Errors) + } + + return nil +} + +type LoginResponse struct { + Token string `json:"token"` +} + +type LoginHandler struct { + AuthManager auth.AuthManager + ReCAPTCHAValidator validators.ReCAPTCHAValidator + MessengerClient message.MessengerClient + Models *data.Models + ReCAPTCHAEnabled bool + MFAEnabled bool +} + +func (h LoginHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + var reqBody LoginRequest + if err := httpdecode.DecodeJSON(req, &reqBody); err != nil { + err = fmt.Errorf("decoding the request body: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + if err := reqBody.validate(); err != nil { + err.Render(rw) + return + } + + if h.ReCAPTCHAEnabled { + // validating reCAPTCHA Token + isValid, err := h.ReCAPTCHAValidator.IsTokenValid(ctx, reqBody.ReCAPTCHAToken) + if err != nil { + httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", err, nil).Render(rw) + return + } + + if !isValid { + log.Ctx(ctx).Errorf("reCAPTCHA token is invalid for request with email %s", utils.TruncateString(reqBody.Email, 3)) + httperror.BadRequest("reCAPTCHA token invalid", nil, nil).Render(rw) + return + } + } + + token, err := h.AuthManager.Authenticate(ctx, reqBody.Email, reqBody.Password) + if errors.Is(err, auth.ErrInvalidCredentials) { + httperror.Unauthorized("", err, map[string]interface{}{"details": "Incorrect email or password"}).Render(rw) + return + } + if err != nil { + log.Ctx(ctx).Errorf("error authenticating user with email %s: %s", utils.TruncateString(reqBody.Email, 3), err) + httperror.InternalError(ctx, "Cannot authenticate user credentials", err, nil).Render(rw) + return + } + + if !h.MFAEnabled { + httpjson.RenderStatus(rw, http.StatusOK, LoginResponse{Token: token}, httpjson.JSON) + return + } + + // πŸ”’ Handling MFA + user, err := h.AuthManager.GetUser(ctx, token) + if err != nil { + log.Ctx(ctx).Errorf("error getting user with email %s: %s", utils.TruncateString(reqBody.Email, 3), err) + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + deviceID := req.Header.Get(DeviceIDHeader) + if deviceID == "" { + httperror.BadRequest("Device-ID header is required", nil, nil).Render(rw) + return + } + + isRemembered, err := h.AuthManager.MFADeviceRemembered(ctx, deviceID, user.ID) + if err != nil { + log.Ctx(ctx).Errorf("error checking if device is remembered for user with email %s: %s", utils.TruncateString(reqBody.Email, 3), err.Error()) + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + if isRemembered { + httpjson.RenderStatus(rw, http.StatusOK, LoginResponse{Token: token}, httpjson.JSON) + return + } + + // Get the MFA code for the user + code, err := h.AuthManager.GetMFACode(ctx, deviceID, user.ID) + if err != nil { + log.Ctx(ctx).Errorf("error getting MFA code for user with email %s: %s", utils.TruncateString(reqBody.Email, 3), err.Error()) + httperror.InternalError(ctx, "Cannot get MFA code", err, nil).Render(rw) + return + } + + if code == "" { + log.Ctx(ctx).Errorf("MFA code for user with email %s is empty", utils.TruncateString(reqBody.Email, 3)) + httperror.InternalError(ctx, "Cannot get MFA code", err, nil).Render(rw) + return + } + + organization, err := h.Models.Organizations.Get(ctx) + if err != nil { + err = fmt.Errorf("error getting organization data: %w", err) + log.Ctx(ctx).Error(err) + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + msgTemplate := htmltemplate.MFAMessageTemplate{ + MFACode: code, + OrganizationName: organization.Name, + } + msgContent, err := htmltemplate.ExecuteHTMLTemplateForMFAMessage(msgTemplate) + if err != nil { + err = fmt.Errorf("error executing mfa message template: %w", err) + log.Ctx(ctx).Error(err) + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + msg := message.Message{ + ToEmail: user.Email, + Title: mfaMessageTitle, + Message: msgContent, + } + err = h.MessengerClient.SendMessage(msg) + if err != nil { + err = fmt.Errorf("error sending mfa code for email %s: %w", user.Email, err) + log.Ctx(ctx).Error(err) + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]string{"message": "MFA code sent to email. Check your inbox and spam folders."}, httpjson.JSON) +} diff --git a/internal/serve/httphandler/login_handler_test.go b/internal/serve/httphandler/login_handler_test.go new file mode 100644 index 000000000..605be0cf2 --- /dev/null +++ b/internal/serve/httphandler/login_handler_test.go @@ -0,0 +1,620 @@ +package httphandler + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_LoginRequest_validate(t *testing.T) { + lr := LoginRequest{ + Email: "", + Password: "", + ReCAPTCHAToken: "", + } + + extras := map[string]interface{}{"email": "email is required", "password": "password is required"} + expectedErr := httperror.BadRequest("Request invalid", nil, extras) + + err := lr.validate() + assert.Equal(t, expectedErr, err) + + lr = LoginRequest{ + Email: "email@email.com", + Password: "", + ReCAPTCHAToken: "XyZ", + } + + extras = map[string]interface{}{"password": "password is required"} + expectedErr = httperror.BadRequest("Request invalid", nil, extras) + + err = lr.validate() + assert.Equal(t, expectedErr, err) +} + +func Test_LoginHandler(t *testing.T) { + r := chi.NewRouter() + + authenticatorMock := &auth.AuthenticatorMock{} + jwtManagerMock := &auth.JWTManagerMock{} + roleManagerMock := &auth.RoleManagerMock{} + reCAPTCHAValidator := &validators.ReCAPTCHAValidatorMock{} + authManager := auth.NewAuthManager( + auth.WithCustomAuthenticatorOption(authenticatorMock), + auth.WithCustomJWTManagerOption(jwtManagerMock), + auth.WithCustomRoleManagerOption(roleManagerMock), + ) + + handler := &LoginHandler{ + AuthManager: authManager, + ReCAPTCHAValidator: reCAPTCHAValidator, + ReCAPTCHAEnabled: true, + } + + const url = "/login" + + t.Run("returns error when body is invalid", func(t *testing.T) { + r.Post(url, handler.ServeHTTP) + + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{}`)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Request invalid", + "extras": { + "email": "email is required", + "password": "password is required" + } + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + req, err = http.NewRequest(http.MethodPost, url, strings.NewReader(`{"email": "testuser"}`)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "Request invalid", + "extras": { + "password": "password is required" + } + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err = http.NewRequest(http.MethodPost, url, strings.NewReader(`"invalid"`)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = `{"error": "The request was invalid in some way."}` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + assert.Contains(t, buf.String(), "decoding the request body") + }) + + t.Run("returns error when an unexpected error occurs validating the credentials", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + + authenticatorMock. + On("ValidateCredentials", mock.Anything, "testuser", "pass1234"). + Return(nil, errors.New("unexpected error")). + Once() + + r.Post(url, handler.ServeHTTP) + + reqBody := ` + { + "email": "testuser", + "password": "pass1234", + "recaptcha_token": "XyZ" + } + ` + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Cannot authenticate user credentials" + } + ` + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + assert.Contains(t, buf.String(), "Cannot authenticate user credentials") + }) + + t.Run("returns error when the credentials are incorrect", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + + authenticatorMock. + On("ValidateCredentials", mock.Anything, "testuser", "pass1234"). + Return(nil, auth.ErrInvalidCredentials). + Once() + + r.Post(url, handler.ServeHTTP) + + reqBody := ` + { + "email": "testuser", + "password": "pass1234", + "recaptcha_token": "XyZ" + } + ` + + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Not authorized.", + "extras": { + "details": "Incorrect email or password" + } + } + ` + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("returns error when unable to validate recaptcha", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(false, errors.New("error requesting verify reCAPTCHA token")). + Once() + + r.Post(url, handler.ServeHTTP) + + reqBody := ` + { + "email": "testuser", + "password": "pass1234", + "recaptcha_token": "XyZ" + } + ` + + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Cannot validate reCAPTCHA token" + } + ` + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("returns error when recaptcha token is invalid", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(false, nil). + Once() + + r.Post(url, handler.ServeHTTP) + + reqBody := ` + { + "email": "testuser", + "password": "pass1234", + "recaptcha_token": "XyZ" + } + ` + + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "reCAPTCHA token invalid" + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("returns the token correctly", func(t *testing.T) { + user := &auth.User{ + ID: "user-ID", + Email: "email", + } + + authenticatorMock. + On("ValidateCredentials", mock.Anything, "testuser", "pass1234"). + Return(user, nil). + Once() + + roleManagerMock. + On("GetUserRoles", mock.Anything, user). + Return([]string{"role1"}, nil). + Once() + + jwtManagerMock. + On("GenerateToken", mock.Anything, user, mock.AnythingOfType("time.Time")). + Return("token123", nil). + Once() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + + r.Post(url, handler.ServeHTTP) + + reqBody := ` + { + "email": "testuser", + "password": "pass1234", + "recaptcha_token": "XyZ" + } + ` + + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"token": "token123"}`, string(respBody)) + }) + + authenticatorMock.AssertExpectations(t) + jwtManagerMock.AssertExpectations(t) + roleManagerMock.AssertExpectations(t) + reCAPTCHAValidator.AssertExpectations(t) +} + +func Test_LoginHandlerr_ServeHTTP_MFA(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, outerErr := data.NewModels(dbConnectionPool) + require.NoError(t, outerErr) + + authenticatorMock := &auth.AuthenticatorMock{} + jwtManagerMock := &auth.JWTManagerMock{} + roleManagerMock := &auth.RoleManagerMock{} + mfaManagerMock := &auth.MFAManagerMock{} + authManager := auth.NewAuthManager( + auth.WithCustomAuthenticatorOption(authenticatorMock), + auth.WithCustomJWTManagerOption(jwtManagerMock), + auth.WithCustomRoleManagerOption(roleManagerMock), + auth.WithCustomMFAManagerOption(mfaManagerMock), + ) + messengerClientMock := &message.MessengerClientMock{} + loginHandler := &LoginHandler{ + AuthManager: authManager, + ReCAPTCHAEnabled: false, + MFAEnabled: true, + Models: models, + MessengerClient: messengerClientMock, + } + + user := &auth.User{ + ID: "userID", + Email: "testuser@mail.com", + } + authenticatorMock. + On("ValidateCredentials", mock.Anything, "testuser@mail.com", "pass1234"). + Return(user, nil) + roleManagerMock. + On("GetUserRoles", mock.Anything, user). + Return([]string{"role1"}, nil) + jwtManagerMock. + On("GenerateToken", mock.Anything, user, mock.AnythingOfType("time.Time")). + Return("token123", nil) + jwtManagerMock. + On("ValidateToken", mock.Anything, "token123"). + Return(true, nil) + jwtManagerMock. + On("GetUserFromToken", mock.Anything, "token123"). + Return(user, nil) + + deviceID := "safari-xyz" + + t.Run("error getting user from token", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(nil, errors.New("weird error happened")). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "An internal error occurred while processing this request") + }) + + t.Run("error when deviceID header is empty", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusBadRequest, rw.Code) + require.Contains(t, rw.Body.String(), "Device-ID header is required") + }) + + t.Run("error validating MFA device", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + mfaManagerMock. + On("MFADeviceRemembered", mock.Anything, deviceID, "userID"). + Return(false, errors.New("weird error happened")). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "An internal error occurred while processing this request") + }) + + t.Run("when device is remembered, return token", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + mfaManagerMock. + On("MFADeviceRemembered", mock.Anything, deviceID, "userID"). + Return(true, nil). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + require.JSONEq(t, `{"token": "token123"}`, rw.Body.String()) + }) + + t.Run("error generating MFA code", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + mfaManagerMock. + On("MFADeviceRemembered", mock.Anything, deviceID, "userID"). + Return(false, nil). + Once() + mfaManagerMock. + On("GenerateMFACode", mock.Anything, deviceID, "userID"). + Return("", errors.New("some weird error")). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "Cannot get MFA code") + }) + + t.Run("error when code returned is empty", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + mfaManagerMock. + On("MFADeviceRemembered", mock.Anything, deviceID, "userID"). + Return(false, nil). + Once() + mfaManagerMock. + On("GenerateMFACode", mock.Anything, deviceID, "userID"). + Return("", nil). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "Cannot get MFA code") + }) + + t.Run("error sending MFA message", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + mfaManagerMock. + On("MFADeviceRemembered", mock.Anything, deviceID, "userID"). + Return(false, nil). + Once() + mfaManagerMock. + On("GenerateMFACode", mock.Anything, deviceID, "userID"). + Return("123123", nil). + Once() + messengerClientMock. + On("SendMessage", mock.Anything). + Return(errors.New("weird error sending message")). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "An internal error occurred while processing this request") + }) + + t.Run("πŸŽ‰ Successful login", func(t *testing.T) { + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + mfaManagerMock. + On("MFADeviceRemembered", mock.Anything, deviceID, "userID"). + Return(false, nil). + Once() + mfaManagerMock. + On("GenerateMFACode", mock.Anything, deviceID, "userID"). + Return("123123", nil). + Once() + + content, err := htmltemplate.ExecuteHTMLTemplateForMFAMessage(htmltemplate.MFAMessageTemplate{ + OrganizationName: "MyCustomAid", + MFACode: "123123", + }) + require.NoError(t, err) + + msg := message.Message{ + ToEmail: "testuser@mail.com", + Title: mfaMessageTitle, + Message: content, + } + messengerClientMock. + On("SendMessage", msg). + Return(nil). + Once() + + body := LoginRequest{Email: "testuser@mail.com", Password: "pass1234"} + req := httptest.NewRequest(http.MethodPost, "/login", requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + loginHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + require.JSONEq(t, `{"message": "MFA code sent to email. Check your inbox and spam folders."}`, rw.Body.String()) + }) + + authenticatorMock.AssertExpectations(t) + jwtManagerMock.AssertExpectations(t) + roleManagerMock.AssertExpectations(t) + mfaManagerMock.AssertExpectations(t) + messengerClientMock.AssertExpectations(t) +} diff --git a/internal/serve/httphandler/mfa_handler.go b/internal/serve/httphandler/mfa_handler.go new file mode 100644 index 000000000..f9512e3a2 --- /dev/null +++ b/internal/serve/httphandler/mfa_handler.go @@ -0,0 +1,83 @@ +package httphandler + +import ( + "errors" + "net/http" + + "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +type MFARequest struct { + MFACode string `json:"mfa_code"` + RememberMe bool `json:"remember_me"` + ReCAPTCHAToken string `json:"recaptcha_token"` +} + +type MFAResponse struct { + Token string `json:"token"` +} + +type MFAHandler struct { + AuthManager auth.AuthManager + ReCAPTCHAValidator validators.ReCAPTCHAValidator + Models *data.Models + ReCAPTCHAEnabled bool +} + +const DeviceIDHeader = "Device-ID" + +func (h MFAHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + var reqBody MFARequest + if err := httpdecode.DecodeJSON(req, &reqBody); err != nil { + log.Ctx(ctx).Errorf("decoding the request body: %s", err.Error()) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + // validating reCAPTCHA Token + if h.ReCAPTCHAEnabled { + isValid, recaptchaErr := h.ReCAPTCHAValidator.IsTokenValid(ctx, reqBody.ReCAPTCHAToken) + if recaptchaErr != nil { + httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", recaptchaErr, nil).Render(rw) + return + } + + if !isValid { + log.Ctx(ctx).Errorf("reCAPTCHA token is invalid for request with email") + httperror.BadRequest("reCAPTCHA token invalid", nil, nil).Render(rw) + return + } + } + + if reqBody.MFACode == "" { + extras := map[string]interface{}{"mfa_code": "MFA Code is required"} + httperror.BadRequest("Request invalid", nil, extras).Render(rw) + return + } + + deviceID := req.Header.Get(DeviceIDHeader) + if deviceID == "" { + httperror.BadRequest("Device-ID header is required", nil, nil).Render(rw) + return + } + + token, err := h.AuthManager.AuthenticateMFA(ctx, deviceID, reqBody.MFACode, reqBody.RememberMe) + if err != nil { + if errors.Is(err, auth.ErrInvalidMFACode) { + httperror.Unauthorized("MFA Code is invalid", err, nil).Render(rw) + return + } + log.Ctx(ctx).Errorf("error authenticating user: %s", err.Error()) + httperror.InternalError(ctx, "Cannot authenticate user", err, nil).Render(rw) + return + } + httpjson.RenderStatus(rw, http.StatusOK, MFAResponse{Token: token}, httpjson.JSON) +} diff --git a/internal/serve/httphandler/mfa_handler_test.go b/internal/serve/httphandler/mfa_handler_test.go new file mode 100644 index 000000000..ce5d01493 --- /dev/null +++ b/internal/serve/httphandler/mfa_handler_test.go @@ -0,0 +1,293 @@ +package httphandler + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const mfaEndpoint = "/mfa" + +func Test_MFAHandler_ServeHTTP(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, outerErr := data.NewModels(dbConnectionPool) + require.NoError(t, outerErr) + + authenticatorMock := &auth.AuthenticatorMock{} + jwtManagerMock := &auth.JWTManagerMock{} + roleManagerMock := &auth.RoleManagerMock{} + reCAPTCHAValidatorMock := &validators.ReCAPTCHAValidatorMock{} + mfaManagerMock := &auth.MFAManagerMock{} + authManager := auth.NewAuthManager( + auth.WithCustomAuthenticatorOption(authenticatorMock), + auth.WithCustomJWTManagerOption(jwtManagerMock), + auth.WithCustomRoleManagerOption(roleManagerMock), + auth.WithCustomMFAManagerOption(mfaManagerMock), + ) + + mfaHandler := MFAHandler{ + AuthManager: authManager, + ReCAPTCHAValidator: reCAPTCHAValidatorMock, + Models: models, + ReCAPTCHAEnabled: true, + } + + deviceID := "safari-xyz" + + t.Run("Test handler with invalid body", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, nil) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusBadRequest, rw.Code) + }) + + t.Run("Test handler with unexpected reCAPTCHA error", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(false, errors.New("unexpected error")). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "Cannot validate reCAPTCHA token") + }) + + t.Run("Test handler with invalid reCAPTCHA token", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(false, nil). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusBadRequest, rw.Code) + require.Contains(t, rw.Body.String(), "reCAPTCHA token invalid") + }) + + t.Run("Test Device ID header is empty", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusBadRequest, rw.Code) + require.Contains(t, rw.Body.String(), "Device-ID header is required") + }) + + t.Run("Test MFA code is empty", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + body := MFARequest{ReCAPTCHAToken: "token"} + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusBadRequest, rw.Code) + require.Contains(t, rw.Body.String(), "MFA Code is required") + }) + + t.Run("Test MFA code is invalid", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + mfaManagerMock. + On("ValidateMFACode", mock.Anything, deviceID, "123456"). + Return("", auth.ErrMFACodeInvalid). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusUnauthorized, rw.Code) + require.Contains(t, rw.Body.String(), "MFA Code is invalid") + }) + + t.Run("Test MFA validation failed", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + mfaManagerMock. + On("ValidateMFACode", mock.Anything, deviceID, "123456"). + Return("", errors.New("weird error happened")). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token"} + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "Cannot authenticate user") + }) + + t.Run("Test MFA remember me failed", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + mfaManagerMock. + On("ValidateMFACode", mock.Anything, deviceID, "123456"). + Return("userID", nil). + Once() + + mfaManagerMock. + On("RememberDevice", mock.Anything, deviceID, "123456"). + Return(errors.New("weird error happened")). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token", RememberMe: true} + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "Cannot authenticate user") + }) + + t.Run("Test MFA get user failed", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + mfaManagerMock. + On("ValidateMFACode", mock.Anything, deviceID, "123456"). + Return("userID", nil). + Once() + + mfaManagerMock. + On("RememberDevice", mock.Anything, deviceID, "123456"). + Return(nil). + Once() + + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(nil, errors.New("weird error happened")). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token", RememberMe: true} + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusInternalServerError, rw.Code) + require.Contains(t, rw.Body.String(), "Cannot authenticate user") + }) + + t.Run("Test MFA validation successful", func(t *testing.T) { + reCAPTCHAValidatorMock. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + mfaManagerMock. + On("ValidateMFACode", mock.Anything, deviceID, "123456"). + Return("userID", nil). + Once() + + mfaManagerMock. + On("RememberDevice", mock.Anything, deviceID, "123456"). + Return(nil). + Once() + + user := &auth.User{ + ID: "user-id", + Email: "email@email.com", + } + + authenticatorMock. + On("GetUser", mock.Anything, "userID"). + Return(user, nil). + Once() + + roleManagerMock. + On("GetUserRoles", mock.Anything, user). + Return([]string{"role1"}, nil). + Once() + + jwtManagerMock. + On("GenerateToken", mock.Anything, user, mock.AnythingOfType("time.Time")). + Return("token123", nil). + Once() + + body := MFARequest{MFACode: "123456", ReCAPTCHAToken: "token", RememberMe: true} + + req := httptest.NewRequest(http.MethodPost, mfaEndpoint, requestToJSON(t, &body)) + req.Header.Set(DeviceIDHeader, deviceID) + rw := httptest.NewRecorder() + + mfaHandler.ServeHTTP(rw, req) + + require.Equal(t, http.StatusOK, rw.Code) + require.JSONEq(t, `{"token": "token123"}`, rw.Body.String()) + }) + + authenticatorMock.AssertExpectations(t) + reCAPTCHAValidatorMock.AssertExpectations(t) +} + +func requestToJSON(t *testing.T, req interface{}) io.Reader { + body, err := json.Marshal(req) + require.NoError(t, err) + return bytes.NewReader(body) +} diff --git a/internal/serve/httphandler/payments_handler.go b/internal/serve/httphandler/payments_handler.go new file mode 100644 index 000000000..88769ddd2 --- /dev/null +++ b/internal/serve/httphandler/payments_handler.go @@ -0,0 +1,135 @@ +package httphandler + +import ( + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpresponse" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/internal/services" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +type PaymentsHandler struct { + Models *data.Models + DBConnectionPool db.DBConnectionPool + AuthManager auth.AuthManager +} + +type RetryPaymentsRequest struct { + PaymentIDs []string `json:"payment_ids"` +} + +func (r *RetryPaymentsRequest) validate() *httperror.HTTPError { + validator := validators.NewValidator() + validator.Check(len(r.PaymentIDs) != 0, "payment_ids", "payment_ids should not be empty") + if validator.HasErrors() { + return httperror.BadRequest("", nil, validator.Errors) + } + + return nil +} + +func (p PaymentsHandler) GetPayment(w http.ResponseWriter, r *http.Request) { + payment_id := chi.URLParam(r, "id") + + payment, err := p.Models.Payment.Get(r.Context(), payment_id, p.DBConnectionPool.SqlxDB()) + if err != nil { + if errors.Is(data.ErrRecordNotFound, err) { + errorResponse := fmt.Sprintf("Cannot retrieve payment with ID: %s", payment_id) + httperror.NotFound(errorResponse, err, nil).Render(w) + return + } else { + ctx := r.Context() + msg := fmt.Sprintf("Cannot retrieve payment with id %s", payment_id) + httperror.InternalError(ctx, msg, err, nil).Render(w) + return + } + } + + httpjson.RenderStatus(w, http.StatusOK, payment, httpjson.JSON) +} + +func (p PaymentsHandler) GetPayments(w http.ResponseWriter, r *http.Request) { + validator := validators.NewPaymentQueryValidator() + queryParams := validator.ParseParametersFromRequest(r) + var err error + + if validator.HasErrors() { + httperror.BadRequest("request invalid", nil, validator.Errors).Render(w) + return + } + + queryParams.Filters = validator.ValidateAndGetPaymentFilters(queryParams.Filters) + if validator.HasErrors() { + httperror.BadRequest("request invalid", nil, validator.Errors).Render(w) + return + } + + ctx := r.Context() + + paymentService := services.NewPaymentService(p.Models, p.DBConnectionPool) + response, err := paymentService.GetPaymentsWithCount(ctx, queryParams) + if err != nil { + httperror.InternalError(ctx, "Cannot retrieve payments", err, nil).Render(w) + return + } + if response.TotalPayments == 0 { + httpjson.RenderStatus(w, http.StatusOK, httpresponse.NewEmptyPaginatedResponse(), httpjson.JSON) + } else { + response, errGet := httpresponse.NewPaginatedResponse(r, response.Payments, queryParams.Page, queryParams.PageLimit, response.TotalPayments) + if errGet != nil { + httperror.InternalError(ctx, "Cannot create paginated payments response", errGet, nil).Render(w) + return + } + httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON) + } +} + +func (p PaymentsHandler) RetryPayments(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + user, err := p.AuthManager.GetUser(ctx, token) + if err != nil { + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + var reqBody RetryPaymentsRequest + if err = httpdecode.DecodeJSON(req, &reqBody); err != nil { + httperror.BadRequest("", err, nil).Render(rw) + return + } + + if err := reqBody.validate(); err != nil { + err.Render(rw) + return + } + + err = p.Models.Payment.RetryFailedPayments(ctx, user.Email, reqBody.PaymentIDs...) + if err != nil { + if errors.Is(err, data.ErrMismatchNumRowsAffected) { + httperror.BadRequest("Invalid payment ID(s) provided. All payment IDs must exist and be in the 'FAILED' state.", err, nil).Render(rw) + return + } + + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]string{"message": "Payments retried successfully"}, httpjson.JSON) +} diff --git a/internal/serve/httphandler/payments_handler_test.go b/internal/serve/httphandler/payments_handler_test.go new file mode 100644 index 000000000..4465822a2 --- /dev/null +++ b/internal/serve/httphandler/payments_handler_test.go @@ -0,0 +1,916 @@ +package httphandler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpresponse" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_PaymentsHandlerGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &PaymentsHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + // setup + r := chi.NewRouter() + r.Get("/payments/{id}", handler.GetPayment) + + ctx := context.Background() + + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.DraftPaymentStatus, + StatusHistory: []data.PaymentStatusHistoryEntry{ + { + Status: data.DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + t.Run("successfully returns payment details for given ID", func(t *testing.T) { + // test + route := fmt.Sprintf("/payments/%s", payment.ID) + req, err := http.NewRequest("GET", route, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + + wantJson := fmt.Sprintf(`{ + "id": %q, + "amount": "50.0000000", + "stellar_transaction_id": %q, + "stellar_operation_id": %q, + "status": "DRAFT", + "status_history": [ + { + "status": "DRAFT", + "status_message": "", + "timestamp": %q + } + ], + "disbursement": { + "id": %q, + "name": "disbursement 1", + "status": "DRAFT", + "created_at": %q, + "updated_at": %q + }, + "asset": { + "id": %q, + "code": "USDC", + "issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "deleted_at": null + }, + "receiver_wallet": { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q + }, + "created_at": %q, + "updated_at": %q + }`, payment.ID, payment.StellarTransactionID, payment.StellarOperationID, payment.StatusHistory[0].Timestamp.Format(time.RFC3339Nano), + disbursement.ID, disbursement.CreatedAt.Format(time.RFC3339Nano), disbursement.UpdatedAt.Format(time.RFC3339Nano), + asset.ID, receiverWallet.ID, receiver.ID, wallet.ID, receiverWallet.StellarAddress, receiverWallet.StellarMemo, + receiverWallet.StellarMemoType, receiverWallet.CreatedAt.Format(time.RFC3339Nano), receiverWallet.UpdatedAt.Format(time.RFC3339Nano), + payment.CreatedAt.Format(time.RFC3339Nano), payment.UpdatedAt.Format(time.RFC3339Nano)) + + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("error payment not found for given ID", func(t *testing.T) { + // test + req, err := http.NewRequest("GET", "/payments/invalid_id", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusNotFound, rr.Code) + + wantJson := `{ + "error": "Cannot retrieve payment with ID: invalid_id" + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) +} + +func Test_PaymentHandler_GetPayments_Errors(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &PaymentsHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ts := httptest.NewServer(http.HandlerFunc(handler.GetPayments)) + defer ts.Close() + + tests := []struct { + name string + queryParams map[string]string + expectedStatusCode int + expectedResponse string + }{ + { + name: "returns error when sort parameter is invalid", + queryParams: map[string]string{ + "sort": "invalid_sort", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"sort":"invalid sort field name"}}`, + }, + { + name: "returns error when direction is invalid", + queryParams: map[string]string{ + "direction": "invalid_direction", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"direction":"invalid sort order. valid values are 'asc' and 'desc'"}}`, + }, + { + name: "returns error when page is invalid", + queryParams: map[string]string{ + "page": "invalid_page", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"page":"parameter must be an integer"}}`, + }, + { + name: "returns error when page_limit is invalid", + queryParams: map[string]string{ + "page_limit": "invalid_page_limit", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"page_limit":"parameter must be an integer"}}`, + }, + { + name: "returns error when status is invalid", + queryParams: map[string]string{ + "status": "invalid_status", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"status":"invalid parameter. valid values are: draft, ready, pending, paused, success, failed"}}`, + }, + { + name: "returns error when created_at_after is invalid", + queryParams: map[string]string{ + "created_at_after": "invalid_created_at_after", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"created_at_after":"invalid date format. valid format is 'YYYY-MM-DD'"}}`, + }, + { + name: "returns error when created_at_before is invalid", + queryParams: map[string]string{ + "created_at_before": "invalid_created_at_before", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"created_at_before":"invalid date format. valid format is 'YYYY-MM-DD'"}}`, + }, + { + name: "returns empty list when no expectedPayments are found", + queryParams: map[string]string{}, + expectedStatusCode: http.StatusOK, + expectedResponse: `{"data":[], "pagination":{"pages":0, "total":0}}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build the URL for the test request + url := buildURLWithQueryParams(ts.URL, "/payments", tc.queryParams) + resp, err := http.Get(url) + require.NoError(t, err) + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tc.expectedStatusCode, resp.StatusCode) + assert.JSONEq(t, tc.expectedResponse, string(respBody)) + }) + } +} + +func Test_PaymentHandler_GetPayments_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &PaymentsHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ts := httptest.NewServer(http.HandlerFunc(handler.GetPayments)) + defer ts.Close() + + ctx := context.Background() + + // create fixtures + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + // create receivers + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.DraftReceiversWalletStatus) + + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.DraftReceiversWalletStatus) + + // create disbursements + disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 2", + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + // create payments + payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.PendingPaymentStatus, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet1, + CreatedAt: time.Date(2022, 12, 10, 23, 40, 20, 1431, time.UTC), + UpdatedAt: time.Date(2023, 3, 10, 23, 40, 20, 1431, time.UTC), + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "150", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.DraftPaymentStatus, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet2, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1431, time.UTC), + UpdatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1431, time.UTC), + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + payment3 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "200.50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.DraftPaymentStatus, + Disbursement: disbursement2, + Asset: *asset, + ReceiverWallet: receiverWallet1, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1431, time.UTC), + UpdatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1431, time.UTC), + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + payment4 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "20", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.PendingPaymentStatus, + Disbursement: disbursement2, + Asset: *asset, + ReceiverWallet: receiverWallet2, + CreatedAt: time.Date(2023, 3, 10, 23, 40, 20, 1431, time.UTC), + UpdatedAt: time.Date(2023, 4, 10, 23, 40, 20, 1431, time.UTC), + }) + + tests := []struct { + name string + queryParams map[string]string + expectedStatusCode int + expectedPagination httpresponse.PaginationInfo + expectedPayments []data.Payment + }{ + { + name: "fetch all payments without filters", + queryParams: map[string]string{}, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 4, + }, + expectedPayments: []data.Payment{*payment4, *payment1, *payment3, *payment2}, + }, + { + name: "fetch first page of payments with limit 1 and sort by created_at", + queryParams: map[string]string{ + "page": "1", + "page_limit": "1", + "sort": "created_at", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "/payments?direction=asc&page=2&page_limit=1&sort=created_at", + Prev: "", + Pages: 4, + Total: 4, + }, + expectedPayments: []data.Payment{*payment1}, + }, + { + name: "fetch second page of payments with limit 1 and sort by created_at", + queryParams: map[string]string{ + "page": "2", + "page_limit": "1", + "sort": "created_at", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "/payments?direction=asc&page=3&page_limit=1&sort=created_at", + Prev: "/payments?direction=asc&page=1&page_limit=1&sort=created_at", + Pages: 4, + Total: 4, + }, + expectedPayments: []data.Payment{*payment2}, + }, + { + name: "fetch last page of payments with limit 1 and sort by created_at", + queryParams: map[string]string{ + "page": "4", + "page_limit": "1", + "sort": "created_at", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "/payments?direction=asc&page=3&page_limit=1&sort=created_at", + Pages: 4, + Total: 4, + }, + expectedPayments: []data.Payment{*payment4}, + }, + { + name: "fetch payments with status draft", + queryParams: map[string]string{ + "status": "dRaFt", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 2, + }, + expectedPayments: []data.Payment{*payment3, *payment2}, + }, + { + name: "fetch payments for receiver1", + queryParams: map[string]string{ + "receiver_id": receiver1.ID, + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 2, + }, + expectedPayments: []data.Payment{*payment1, *payment3}, + }, + { + name: "fetch payments for receiver2", + queryParams: map[string]string{ + "receiver_id": receiver2.ID, + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 2, + }, + expectedPayments: []data.Payment{*payment4, *payment2}, + }, + { + name: "returns empty list when receiver_id is not found", + queryParams: map[string]string{ + "receiver_id": "invalid_id", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 0, + Total: 0, + }, + expectedPayments: []data.Payment{}, + }, + { + name: "fetch payments created at before 2023-01-01", + queryParams: map[string]string{ + "created_at_before": "2023-01-01", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 1, + }, + expectedPayments: []data.Payment{*payment1}, + }, + { + name: "fetch payments after 2023-03-01", + queryParams: map[string]string{ + "created_at_after": "2023-03-01", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 1, + }, + expectedPayments: []data.Payment{*payment4}, + }, + { + name: "fetch payment created at after 2023-01-01 and before 2023-03-01", + queryParams: map[string]string{ + "created_at_after": "2023-01-01", + "created_at_before": "2023-03-01", + }, + expectedStatusCode: http.StatusOK, + expectedPagination: httpresponse.PaginationInfo{ + Next: "", + Prev: "", + Pages: 1, + Total: 2, + }, + expectedPayments: []data.Payment{*payment3, *payment2}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build the URL for the test request + url := buildURLWithQueryParams(ts.URL, "/payments", tc.queryParams) + resp, err := http.Get(url) + require.NoError(t, err) + defer resp.Body.Close() + + // Parse the response body + var actualResponse httpresponse.PaginatedResponse + err = json.NewDecoder(resp.Body).Decode(&actualResponse) + require.NoError(t, err) + + // Assert on the pagination data + assert.Equal(t, tc.expectedPagination, actualResponse.Pagination) + + // Parse the response data + var actualPayments []data.Payment + err = json.Unmarshal(actualResponse.Data, &actualPayments) + require.NoError(t, err) + + // Assert on the payments data + assert.Equal(t, tc.expectedPayments, actualPayments) + }) + } +} + +func Test_PaymentHandler_RetryPayments(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + authManagerMock := &auth.AuthManagerMock{} + + handler := PaymentsHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + AuthManager: authManagerMock, + } + + ctx := context.Background() + + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://") + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Country: country, + Wallet: wallet, + Asset: asset, + Status: data.ReadyDisbursementStatus, + VerificationField: data.VerificationFieldDateOfBirth, + }) + + t.Run("returns Unauthorized when no token in the context", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, "/retry", strings.NewReader("{}")) + require.NoError(t, err) + + rw := httptest.NewRecorder() + http.HandlerFunc(handler.RetryPayments).ServeHTTP(rw, req) + + resp := rw.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + }) + + t.Run("returns InternalServerError when fails getting user from token", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, "/retry", strings.NewReader("{}")) + require.NoError(t, err) + + authManagerMock. + On("GetUser", ctx, "mytoken"). + Return(nil, errors.New("unexpected error")). + Once() + + rw := httptest.NewRecorder() + http.HandlerFunc(handler.RetryPayments).ServeHTTP(rw, req) + + resp := rw.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "An internal error occurred while processing this request."}`, string(respBody)) + }) + + t.Run("returns BadRequest when fails decoding body request", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + payload := strings.NewReader("invalid") + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, "/retry", payload) + require.NoError(t, err) + + authManagerMock. + On("GetUser", ctx, "mytoken"). + Return(&auth.User{ + Email: "email@test.com", + }, nil). + Once() + + rw := httptest.NewRecorder() + http.HandlerFunc(handler.RetryPayments).ServeHTTP(rw, req) + + resp := rw.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way."}`, string(respBody)) + }) + + t.Run("returns BadRequest when fails when payload is invalid", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + payload := strings.NewReader("{}") + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, "/retry", payload) + require.NoError(t, err) + + authManagerMock. + On("GetUser", ctx, "mytoken"). + Return(&auth.User{ + Email: "email@test.com", + }, nil). + Once() + + rw := httptest.NewRecorder() + http.HandlerFunc(handler.RetryPayments).ServeHTTP(rw, req) + + resp := rw.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"payment_ids": "payment_ids should not be empty"}}`, string(respBody)) + }) + + t.Run("returns BadRequest when some payments IDs are not in the failed state", func(t *testing.T) { + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + + payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: data.PendingPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: data.ReadyPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment3 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-3", + StellarOperationID: "operation-id-3", + Status: data.FailedPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment4 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-4", + StellarOperationID: "operation-id-4", + Status: data.FailedPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + payload := strings.NewReader(fmt.Sprintf(` + { + "payment_ids": [ + %q, + %q, + %q, + %q + ] + } + `, payment1.ID, payment2.ID, payment3.ID, payment4.ID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, "/retry", payload) + require.NoError(t, err) + + authManagerMock. + On("GetUser", ctx, "mytoken"). + Return(&auth.User{ + Email: "email@test.com", + }, nil). + Once() + + rw := httptest.NewRecorder() + http.HandlerFunc(handler.RetryPayments).ServeHTTP(rw, req) + + resp := rw.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "Invalid payment ID(s) provided. All payment IDs must exist and be in the 'FAILED' state."}`, string(respBody)) + + payment1DB, err := models.Payment.Get(ctx, payment1.ID, dbConnectionPool) + require.NoError(t, err) + + payment2DB, err := models.Payment.Get(ctx, payment2.ID, dbConnectionPool) + require.NoError(t, err) + + payment3DB, err := models.Payment.Get(ctx, payment3.ID, dbConnectionPool) + require.NoError(t, err) + + payment4DB, err := models.Payment.Get(ctx, payment4.ID, dbConnectionPool) + require.NoError(t, err) + + // Payment 1 + assert.Equal(t, data.PendingPaymentStatus, payment1DB.Status) + assert.Equal(t, payment1.StellarTransactionID, payment1DB.StellarTransactionID) + assert.Equal(t, payment1.StatusHistory, payment1DB.StatusHistory) + + // Payment 2 + assert.Equal(t, data.ReadyPaymentStatus, payment2DB.Status) + assert.Equal(t, payment2.StellarTransactionID, payment2DB.StellarTransactionID) + assert.Equal(t, payment2.StatusHistory, payment2DB.StatusHistory) + + // Payment 3 + assert.Equal(t, data.FailedPaymentStatus, payment3DB.Status) + assert.Equal(t, payment3.StellarTransactionID, payment3DB.StellarTransactionID) + assert.Equal(t, payment3.StatusHistory, payment3DB.StatusHistory) + + // Payment 4 + assert.Equal(t, data.FailedPaymentStatus, payment4DB.Status) + assert.Equal(t, payment4.StellarTransactionID, payment4DB.StellarTransactionID) + assert.Equal(t, payment4.StatusHistory, payment4DB.StatusHistory) + }) + + t.Run("successfully retries failed payments", func(t *testing.T) { + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + + payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: data.FailedPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "1", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: data.FailedPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + payload := strings.NewReader(fmt.Sprintf(` + { + "payment_ids": [ + %q, + %q + ] + } + `, payment1.ID, payment2.ID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, "/retry", payload) + require.NoError(t, err) + + authManagerMock. + On("GetUser", ctx, "mytoken"). + Return(&auth.User{ + Email: "email@test.com", + }, nil). + Once() + + rw := httptest.NewRecorder() + http.HandlerFunc(handler.RetryPayments).ServeHTTP(rw, req) + + resp := rw.Result() + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "Payments retried successfully"}`, string(respBody)) + + payment1DB, err := models.Payment.Get(ctx, payment1.ID, dbConnectionPool) + require.NoError(t, err) + + payment2DB, err := models.Payment.Get(ctx, payment2.ID, dbConnectionPool) + require.NoError(t, err) + + // Payment 1 + assert.Equal(t, data.ReadyPaymentStatus, payment1DB.Status) + assert.Empty(t, payment1DB.StellarTransactionID) + assert.NotEqual(t, payment1.StatusHistory, payment1DB.StatusHistory) + assert.Len(t, payment1DB.StatusHistory, 2) + assert.Equal(t, data.ReadyPaymentStatus, payment1DB.StatusHistory[1].Status) + assert.Equal(t, "User email@test.com has requested to retry the payment - Previous Stellar Transaction ID: stellar-transaction-id-1", payment1DB.StatusHistory[1].StatusMessage) + + // Payment 2 + assert.Equal(t, data.ReadyPaymentStatus, payment2DB.Status) + assert.Empty(t, payment2DB.StellarTransactionID) + assert.NotEqual(t, payment2.StatusHistory, payment2DB.StatusHistory) + assert.Len(t, payment2DB.StatusHistory, 2) + assert.Equal(t, data.ReadyPaymentStatus, payment2DB.StatusHistory[1].Status) + assert.Equal(t, "User email@test.com has requested to retry the payment - Previous Stellar Transaction ID: stellar-transaction-id-2", payment2DB.StatusHistory[1].StatusMessage) + }) +} diff --git a/internal/serve/httphandler/profile_handler.go b/internal/serve/httphandler/profile_handler.go new file mode 100644 index 000000000..15204cf79 --- /dev/null +++ b/internal/serve/httphandler/profile_handler.go @@ -0,0 +1,322 @@ +package httphandler + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "image" + + // Don't remove the `image/jpeg` and `image/png` packages import unless + // the `image` package is no longer necessary. + // It registers the `Decoders` to handle the image decoding - `image.Decode`. + // See https://pkg.go.dev/image#pkg-overview + _ "image/jpeg" + _ "image/png" + "io" + "io/fs" + "net/http" + "net/url" + "strings" + + "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +// DefaultMaxMemoryAllocation limits the max of memory allocation up to 2MB +// when parsing the multipart form data request +const DefaultMaxMemoryAllocation = 2 * 1024 * 1024 + +type ProfileHandler struct { + Models *data.Models + AuthManager auth.AuthManager + MaxMemoryAllocation int64 + BaseURL string + PublicFilesFS fs.FS + DistributionPublicKey string +} + +type PatchOrganizationProfileRequest struct { + OrganizationName string `json:"organization_name"` + TimezoneUTCOffset string `json:"timezone_utc_offset"` + IsApprovalRequired *bool `json:"is_approval_required"` +} + +func (r *PatchOrganizationProfileRequest) AreAllFieldsEmpty() bool { + return r.OrganizationName == "" && r.TimezoneUTCOffset == "" && r.IsApprovalRequired == nil +} + +type PatchUserProfileRequest struct { + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + Password string `json:"password"` +} + +type GetProfileResponse struct { + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + Roles []string `json:"roles"` + OrganizationName string `json:"organization_name"` +} + +func (h ProfileHandler) PatchOrganizationProfile(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + _, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + // limiting the size of the request + req.Body = http.MaxBytesReader(rw, req.Body, h.MaxMemoryAllocation) + + // limiting the amount of memory allocated in the server to handle the request + if err := req.ParseMultipartForm(h.MaxMemoryAllocation); err != nil { + err = fmt.Errorf("error parsing multipart form: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("could not parse multipart form data", err, map[string]interface{}{ + "details": "request too large. Max size 2MB.", + }).Render(rw) + return + } + + multipartFile, _, err := req.FormFile("logo") + if err != nil && !errors.Is(err, http.ErrMissingFile) { + err = fmt.Errorf("error parsing logo file: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("could not parse request logo", err, nil).Render(rw) + return + } + + var fileContentBytes []byte + // a file is present in the request + if multipartFile != nil { + fileContentBytes, err = io.ReadAll(multipartFile) + if err != nil { + httperror.InternalError(ctx, "Cannot read file contents", err, nil).Render(rw) + return + } + + // We need to ensure the the type of file is one of the accepted - image/png and image/jpeg + fileContentType := http.DetectContentType(fileContentBytes) + + validator := validators.NewValidator() + expectedContentTypes := fmt.Sprintf("%s %s", data.PNGLogoType.ToHTTPContentType(), data.JPEGLogoType.ToHTTPContentType()) + validator.Check(strings.Contains(expectedContentTypes, fileContentType), "logo", "invalid file type provided. Expected png or jpeg.") + if validator.HasErrors() { + httperror.BadRequest("", nil, validator.Errors).Render(rw) + return + } + } + + var reqBody PatchOrganizationProfileRequest + d := req.FormValue("data") + if err = json.Unmarshal([]byte(d), &reqBody); err != nil { + err = fmt.Errorf("error decoding data: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + // validate wether the logo or the organization_name were sent in the request + if len(fileContentBytes) == 0 && reqBody.AreAllFieldsEmpty() { + httperror.BadRequest("request is invalid", nil, map[string]interface{}{ + "details": "data or logo is required", + }).Render(rw) + return + } + + err = h.Models.Organizations.Update(ctx, &data.OrganizationUpdate{ + Name: reqBody.OrganizationName, + Logo: fileContentBytes, + TimezoneUTCOffset: reqBody.TimezoneUTCOffset, + IsApprovalRequired: reqBody.IsApprovalRequired, + }) + if err != nil { + httperror.InternalError(ctx, "Cannot update organization", err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]string{"message": "organization profile updated successfully"}, httpjson.JSON) +} + +func (h ProfileHandler) PatchUserProfile(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + var reqBody PatchUserProfileRequest + if err := httpdecode.DecodeJSON(req, &reqBody); err != nil { + err = fmt.Errorf("decoding the request body: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + if reqBody.Password != "" && len(reqBody.Password) < 8 { + httperror.BadRequest("", nil, map[string]interface{}{ + "password": "password should have at least 8 characters", + }).Render(rw) + return + } + + if reqBody.Email != "" { + if err := utils.ValidateEmail(reqBody.Email); err != nil { + httperror.BadRequest("", nil, map[string]interface{}{ + "email": "invalid email provided", + }).Render(rw) + return + } + } + + if reqBody.FirstName == "" && reqBody.LastName == "" && reqBody.Email == "" && reqBody.Password == "" { + httperror.BadRequest("", nil, map[string]interface{}{ + "details": "provide at least first_name, last_name, email or password.", + }).Render(rw) + return + } + + err := h.AuthManager.UpdateUser(ctx, token, reqBody.FirstName, reqBody.LastName, reqBody.Email, reqBody.Password) + if err != nil { + httperror.InternalError(ctx, "Cannot update user profiles", err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]string{"message": "user profile updated successfully"}, httpjson.JSON) +} + +func (h ProfileHandler) GetProfile(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + user, err := h.AuthManager.GetUser(ctx, token) + if err != nil { + if errors.Is(err, auth.ErrInvalidToken) { + err = fmt.Errorf("getting user profile: %w", err) + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(rw) + return + } + + if errors.Is(err, auth.ErrUserNotFound) { + err = fmt.Errorf("user from token %s not found: %w", token, err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + httperror.InternalError(ctx, "Cannot get user", err, nil).Render(rw) + return + } + + org, err := h.Models.Organizations.Get(ctx) + if err != nil { + httperror.InternalError(ctx, "Cannot get organization", err, nil).Render(rw) + return + } + + resp := &GetProfileResponse{ + FirstName: user.FirstName, + LastName: user.LastName, + Email: user.Email, + Roles: user.Roles, + OrganizationName: org.Name, + } + httpjson.RenderStatus(rw, http.StatusOK, resp, httpjson.JSON) +} + +func (h ProfileHandler) GetOrganizationInfo(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + // We first build the logo URL so we don't hit the database if any error occurs. + logoURL, err := url.JoinPath(h.BaseURL, "organization", "logo") + if err != nil { + httperror.InternalError(ctx, "Cannot get logo URL", err, nil).Render(rw) + return + } + + lu, err := url.Parse(logoURL) + if err != nil { + httperror.InternalError(ctx, "Cannot parse logo URL", err, nil).Render(rw) + return + } + + q := lu.Query() + q.Add("token", token) + lu.RawQuery = q.Encode() + + org, err := h.Models.Organizations.Get(ctx) + if err != nil { + httperror.InternalError(ctx, "Cannot get organization", err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]interface{}{ + "name": org.Name, + "logo_url": lu.String(), + "distribution_account_public_key": h.DistributionPublicKey, + "timezone_utc_offset": org.TimezoneUTCOffset, + "is_approval_required": org.IsApprovalRequired, + }, httpjson.JSON) +} + +// GetOrganizationLogo renders the stored organization logo. The image is rendered inline (not attached - the attached option downloads the content) +// so the client can embed the image. +func (h ProfileHandler) GetOrganizationLogo(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + org, err := h.Models.Organizations.Get(ctx) + if err != nil { + httperror.InternalError(ctx, "Cannot get organization", err, nil).Render(rw) + return + } + + if len(org.Logo) == 0 { + var logoBytes []byte + logoBytes, err = fs.ReadFile(h.PublicFilesFS, "img/logo.png") + if err != nil { + httperror.InternalError(ctx, "Cannot open default logo", err, nil).Render(rw) + return + } + + org.Logo = logoBytes + } + + _, ext, err := image.Decode(bytes.NewReader(org.Logo)) + if err != nil { + httperror.InternalError(ctx, "Cannot decode organization logo", err, nil).Render(rw) + return + } + + rw.Header().Set("Content-Disposition", fmt.Sprintf(`inline; filename="%s"`, fmt.Sprintf("logo.%s", ext))) + rw.Header().Set("Content-Type", http.DetectContentType(org.Logo)) + _, err = rw.Write(org.Logo) + if err != nil { + httperror.InternalError(ctx, "Cannot write organization logo to response", err, nil).Render(rw) + } +} diff --git a/internal/serve/httphandler/profile_handler_test.go b/internal/serve/httphandler/profile_handler_test.go new file mode 100644 index 000000000..f515847da --- /dev/null +++ b/internal/serve/httphandler/profile_handler_test.go @@ -0,0 +1,988 @@ +package httphandler + +import ( + "bytes" + "context" + "encoding/csv" + "errors" + "fmt" + "image/jpeg" + "image/png" + "io" + "io/fs" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/publicfiles" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createOrganizationProfileMultipartRequest(t *testing.T, url, fieldName, filename, body string, fileContent io.Reader) (*http.Request, error) { + buf := new(bytes.Buffer) + writer := multipart.NewWriter(buf) + defer writer.Close() + + if fieldName == "" { + fieldName = "logo" + } + + part, err := writer.CreateFormFile(fieldName, filename) + require.NoError(t, err) + + _, err = io.Copy(part, fileContent) + require.NoError(t, err) + + // adding the data + err = writer.WriteField("data", body) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPatch, url, buf) + require.NoError(t, err) + req.Header.Set("Content-Type", writer.FormDataContentType()) + return req, nil +} + +func Test_PatchOrganizationProfileRequest_AreAllFieldsEmpty(t *testing.T) { + r := &PatchOrganizationProfileRequest{ + OrganizationName: "", + TimezoneUTCOffset: "", + } + res := r.AreAllFieldsEmpty() + assert.True(t, res) + + r = &PatchOrganizationProfileRequest{ + OrganizationName: "MyAid", + TimezoneUTCOffset: "", + } + res = r.AreAllFieldsEmpty() + assert.False(t, res) + + r = &PatchOrganizationProfileRequest{ + OrganizationName: "", + TimezoneUTCOffset: "-03:00", + } + res = r.AreAllFieldsEmpty() + assert.False(t, res) +} + +func Test_ProfileHandler_PatchOrganizationProfile(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &ProfileHandler{Models: models, MaxMemoryAllocation: DefaultMaxMemoryAllocation} + url := "/profile/organization" + + resetOrganizationInfo := func(t *testing.T, ctx context.Context) { + const q = "UPDATE organizations SET name = 'MyCustomAid', logo = NULL, timezone_utc_offset = '+00:00'" + _, err := dbConnectionPool.ExecContext(ctx, q) + require.NoError(t, err) + } + + ctx := context.Background() + + t.Run("returns Unauthorized error when no token is found", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + }) + + t.Run("returns BadRequest error when the request is invalid", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + // Invalid JSON data + img := data.CreateMockImage(t, 300, 300, data.ImageSizeSmall) + imgBuf := new(bytes.Buffer) + err := png.Encode(imgBuf, img) + require.NoError(t, err) + + w := httptest.NewRecorder() + req, err := createOrganizationProfileMultipartRequest(t, url, "logo", "logo.png", `invalid`, imgBuf) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way."}`, string(respBody)) + + // Invalid file format + csvBuf := new(bytes.Buffer) + csvWriter := csv.NewWriter(csvBuf) + err = csvWriter.WriteAll([][]string{ + {"name", "age"}, + {"foo", "99"}, + {"bar", "99"}, + }) + require.NoError(t, err) + + w = httptest.NewRecorder() + req, err = createOrganizationProfileMultipartRequest(t, url, "logo", "logo.csv", `{}`, csvBuf) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "The request was invalid in some way.", + "extras": { + "logo": "invalid file type provided. Expected png or jpeg." + } + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + // Neither logo and organization_name isn't present. + w = httptest.NewRecorder() + req, err = createOrganizationProfileMultipartRequest(t, url, "wrong", "logo.png", `{}`, new(bytes.Buffer)) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "request is invalid", "extras": {"details": "data or logo is required"}}`, string(respBody)) + }) + + t.Run("returns BadRequest error when the request size is too large", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + img := data.CreateMockImage(t, 3840, 2160, data.ImageSizeMedium) + imgBuf := new(bytes.Buffer) + err := jpeg.Encode(imgBuf, img, &jpeg.Options{Quality: jpeg.DefaultQuality}) + require.NoError(t, err) + + w := httptest.NewRecorder() + req, err := createOrganizationProfileMultipartRequest(t, url, "logo", "logo.jpeg", `{}`, imgBuf) + require.NoError(t, err) + + req = req.WithContext(ctx) + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + profileHandler := &ProfileHandler{Models: models, MaxMemoryAllocation: 1024 * 1024} + http.HandlerFunc(profileHandler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "could not parse multipart form data", "extras": {"details": "request too large. Max size 2MB."}}`, string(respBody)) + + entries := getEntries() + assert.Equal(t, "error parsing multipart form: http: request body too large", entries[0].Message) + }) + + t.Run("updates the organization's name successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + org, err := models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "MyCustomAid", org.Name) + assert.Nil(t, org.Logo) + + w := httptest.NewRecorder() + req, err := createOrganizationProfileMultipartRequest(t, url, "", "", `{"organization_name": "My Org Name"}`, new(bytes.Buffer)) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "organization profile updated successfully"}`, string(respBody)) + + org, err = models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "My Org Name", org.Name) + assert.Nil(t, org.Logo) + }) + + t.Run("updates the organization's timezone UTC offset successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + org, err := models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "+00:00", org.TimezoneUTCOffset) + assert.Equal(t, "MyCustomAid", org.Name) + assert.Nil(t, org.Logo) + + w := httptest.NewRecorder() + req, err := createOrganizationProfileMultipartRequest(t, url, "", "", `{"timezone_utc_offset": "-03:00"}`, new(bytes.Buffer)) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "organization profile updated successfully"}`, string(respBody)) + + org, err = models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "-03:00", org.TimezoneUTCOffset) + assert.Equal(t, "MyCustomAid", org.Name) + assert.Nil(t, org.Logo) + }) + + t.Run("updates the organization's IsApprovalRequired successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + org, err := models.Organizations.Get(ctx) + require.NoError(t, err) + assert.False(t, org.IsApprovalRequired) + + w := httptest.NewRecorder() + req, err := createOrganizationProfileMultipartRequest(t, url, "", "", `{"is_approval_required": true}`, new(bytes.Buffer)) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "organization profile updated successfully"}`, string(respBody)) + + org, err = models.Organizations.Get(ctx) + require.NoError(t, err) + require.True(t, org.IsApprovalRequired) + }) + + t.Run("updates the organization's logo successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + // PNG logo + org, err := models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Nil(t, org.Logo) + assert.Equal(t, "MyCustomAid", org.Name) + + img := data.CreateMockImage(t, 300, 300, data.ImageSizeSmall) + imgBuf := new(bytes.Buffer) + err = png.Encode(imgBuf, img) + require.NoError(t, err) + + w := httptest.NewRecorder() + req, err := createOrganizationProfileMultipartRequest(t, url, "logo", "logo.png", `{}`, imgBuf) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "organization profile updated successfully"}`, string(respBody)) + + org, err = models.Organizations.Get(ctx) + require.NoError(t, err) + + // renew buffer + imgBuf = new(bytes.Buffer) + err = png.Encode(imgBuf, img) + require.NoError(t, err) + + assert.Equal(t, imgBuf.Bytes(), org.Logo) + assert.Equal(t, "MyCustomAid", org.Name) + + // JPEG logo + resetOrganizationInfo(t, ctx) + + org, err = models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Nil(t, org.Logo) + assert.Equal(t, "MyCustomAid", org.Name) + + img = data.CreateMockImage(t, 300, 300, data.ImageSizeSmall) + imgBuf = new(bytes.Buffer) + err = jpeg.Encode(imgBuf, img, &jpeg.Options{Quality: jpeg.DefaultQuality}) + require.NoError(t, err) + + w = httptest.NewRecorder() + req, err = createOrganizationProfileMultipartRequest(t, url, "logo", "logo.jpeg", `{}`, imgBuf) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "organization profile updated successfully"}`, string(respBody)) + + org, err = models.Organizations.Get(ctx) + require.NoError(t, err) + + // renew buffer + imgBuf = new(bytes.Buffer) + err = jpeg.Encode(imgBuf, img, &jpeg.Options{Quality: jpeg.DefaultQuality}) + require.NoError(t, err) + + assert.Equal(t, imgBuf.Bytes(), org.Logo) + assert.Equal(t, "MyCustomAid", org.Name) + }) + + t.Run("updates both organization name, timezone UTC offset and logo successfully", func(t *testing.T) { + resetOrganizationInfo(t, ctx) + + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + org, err := models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, "MyCustomAid", org.Name) + assert.Equal(t, "+00:00", org.TimezoneUTCOffset) + assert.Nil(t, org.Logo) + + img := data.CreateMockImage(t, 300, 300, data.ImageSizeSmall) + imgBuf := new(bytes.Buffer) + err = png.Encode(imgBuf, img) + require.NoError(t, err) + + w := httptest.NewRecorder() + req, err := createOrganizationProfileMultipartRequest(t, url, "logo", "logo.png", `{"organization_name": "My Org Name", "timezone_utc_offset": "-03:00"}`, imgBuf) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchOrganizationProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "organization profile updated successfully"}`, string(respBody)) + + org, err = models.Organizations.Get(ctx) + require.NoError(t, err) + + // renew buffer + imgBuf = new(bytes.Buffer) + err = png.Encode(imgBuf, img) + require.NoError(t, err) + + assert.Equal(t, "My Org Name", org.Name) + assert.Equal(t, "-03:00", org.TimezoneUTCOffset) + assert.Equal(t, imgBuf.Bytes(), org.Logo) + }) +} + +func Test_ProfileHandler_PatchUserProfile(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + authenticatorMock := &auth.AuthenticatorMock{} + jwtManagerMock := &auth.JWTManagerMock{} + authManager := auth.NewAuthManager( + auth.WithCustomAuthenticatorOption(authenticatorMock), + auth.WithCustomJWTManagerOption(jwtManagerMock), + ) + + handler := &ProfileHandler{AuthManager: authManager} + url := "/profile/user" + + ctx := context.Background() + + t.Run("returns Unauthorized error when no token is found", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.PatchUserProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + }) + + t.Run("returns BadRequest error when the request is invalid", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + // Invalid JSON + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`invalid`)) + require.NoError(t, err) + + http.HandlerFunc(handler.PatchUserProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way."}`, string(respBody)) + + // Invalid email + w = httptest.NewRecorder() + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`{"email": "invalid"}`)) + require.NoError(t, err) + + req = req.WithContext(ctx) + + http.HandlerFunc(handler.PatchUserProfile).ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"email": "invalid email provided"}}`, string(respBody)) + + // Password too short + w = httptest.NewRecorder() + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`{"password": "short"}`)) + require.NoError(t, err) + + http.HandlerFunc(handler.PatchUserProfile).ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"password": "password should have at least 8 characters"}}`, string(respBody)) + + // None of values provided + w = httptest.NewRecorder() + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`{}`)) + require.NoError(t, err) + + http.HandlerFunc(handler.PatchUserProfile).ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"details":"provide at least first_name, last_name, email or password."}}`, string(respBody)) + }) + + t.Run("returns InternalServerError when AuthManager fails", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + reqBody := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "password": "mypassword" + } + ` + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), "token"). + Return(true, nil). + Once(). + On("GetUserFromToken", req.Context(), "token"). + Return(&auth.User{ID: "user-id"}, nil). + Once() + + authenticatorMock. + On("UpdateUser", req.Context(), "user-id", "First", "Last", "email@email.com", "mypassword"). + Return(errors.New("unexpected error")). + Once() + + http.HandlerFunc(handler.PatchUserProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error":"Cannot update user profiles"}`, string(respBody)) + }) + + t.Run("updates the user profile successfully", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "token") + + reqBody := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "password": "mypassword" + } + ` + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), "token"). + Return(true, nil). + Once(). + On("GetUserFromToken", req.Context(), "token"). + Return(&auth.User{ID: "user-id"}, nil). + Once() + + authenticatorMock. + On("UpdateUser", req.Context(), "user-id", "First", "Last", "email@email.com", "mypassword"). + Return(nil). + Once() + + http.HandlerFunc(handler.PatchUserProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "user profile updated successfully"}`, string(respBody)) + }) + + authenticatorMock.AssertExpectations(t) + jwtManagerMock.AssertExpectations(t) +} + +func Test_ProfileHandler_GetProfile(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + authManagerMock := &auth.AuthManagerMock{} + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &ProfileHandler{Models: models, AuthManager: authManagerMock} + url := "/profile" + + ctx := context.Background() + + t.Run("returns Unauthorized error when no token is found", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.GetProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + }) + + t.Run("returns Unauthorized when AuthManager fails with ErrInvalidToken", func(t *testing.T) { + token := "mytoken" + ctx = context.WithValue(ctx, middleware.TokenContextKey, token) + + expectedErr := auth.ErrInvalidToken + authManagerMock. + On("GetUser", ctx, token). + Return(nil, expectedErr). + Once() + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + http.HandlerFunc(handler.GetProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + + entries := getEntries() + expectedLog := fmt.Sprintf("getting user profile: %s", expectedErr) + assert.Equal(t, expectedLog, entries[0].Message) + }) + + t.Run("returns BadRequest when user is not found", func(t *testing.T) { + token := "mytoken" + ctx = context.WithValue(ctx, middleware.TokenContextKey, token) + expectedErr := fmt.Errorf("error getting user ID %s: %w", "user-id", auth.ErrUserNotFound) + + authManagerMock. + On("GetUser", ctx, token). + Return(nil, expectedErr). + Once() + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + http.HandlerFunc(handler.GetProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way."}`, string(respBody)) + + entries := getEntries() + expectedLog := fmt.Sprintf("user from token mytoken not found: %s", expectedErr) + assert.Equal(t, expectedLog, entries[0].Message) + }) + + t.Run("returns InternalServerError when AuthManager fails", func(t *testing.T) { + token := "mytoken" + ctx = context.WithValue(ctx, middleware.TokenContextKey, token) + + expectedErr := errors.New("error getting user ID user-id: unexpected error") + authManagerMock. + On("GetUser", ctx, token). + Return(nil, expectedErr). + Once() + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + http.HandlerFunc(handler.GetProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot get user"}`, string(respBody)) + + entries := getEntries() + expectedLog := fmt.Sprintf("Cannot get user: %s", expectedErr) + assert.Equal(t, expectedLog, entries[0].Message) + }) + + t.Run("returns the profile info successfully", func(t *testing.T) { + token := "mytoken" + ctx = context.WithValue(ctx, middleware.TokenContextKey, token) + + u := &auth.User{ + ID: "user-id", + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []string{data.DeveloperUserRole.String()}, + } + + authManagerMock. + On("GetUser", ctx, token). + Return(u, nil). + Once() + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.GetProfile).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "email": "email@email.com", + "first_name": "First", + "last_name": "Last", + "organization_name": "MyCustomAid", + "roles": ["developer"] + } + ` + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + authManagerMock.AssertExpectations(t) +} + +func Test_ProfileHandler_GetOrganizationInfo(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + distributionAccountPK := keypair.MustRandom().Address() + handler := &ProfileHandler{Models: models, BaseURL: "http://localhost:8000", DistributionPublicKey: distributionAccountPK} + url := "/profile/info" + + ctx := context.Background() + + t.Run("returns Unauthorized error when no token is found", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.GetOrganizationInfo).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + }) + + t.Run("returns InternalServerError if getting logo URL fails", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + h := &ProfileHandler{Models: models, BaseURL: "%invalid%"} + http.HandlerFunc(h.GetOrganizationInfo).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot get logo URL"}`, string(respBody)) + + entries := getEntries() + assert.Equal(t, `Cannot get logo URL: parse "%invalid%": invalid URL escape "%in"`, entries[0].Message) + }) + + t.Run("returns the organization info successfully", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.GetOrganizationInfo).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := fmt.Sprintf(` + { + "logo_url": "http://localhost:8000/organization/logo?token=mytoken", + "name": "MyCustomAid", + "distribution_account_public_key": %q, + "timezone_utc_offset": "+00:00", + "is_approval_required":false + } + `, distributionAccountPK) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) +} + +func Test_ProfileHandler_GetOrganizationLogo(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, outerErr := data.NewModels(dbConnectionPool) + require.NoError(t, outerErr) + + handler := &ProfileHandler{Models: models, PublicFilesFS: publicfiles.PublicFiles} + url := "/organization/logo" + + ctx := context.Background() + + t.Run("returns InternalServerError when can't find the default logo file", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + fsMap := fstest.MapFS{} + h := &ProfileHandler{Models: models, PublicFilesFS: fsMap} + http.HandlerFunc(h.GetOrganizationLogo).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot open default logo"}`, string(respBody)) + + entries := getEntries() + assert.NotEmpty(t, entries) + assert.Equal(t, `Cannot open default logo: open img/logo.png: file does not exist`, entries[0].Message) + }) + + t.Run("returns the default logo when no logo is set", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.GetOrganizationLogo).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedLogoBytes, err := fs.ReadFile(publicfiles.PublicFiles, "img/logo.png") + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, expectedLogoBytes, respBody) + }) + + t.Run("returns the organization logo stored in the database successfully", func(t *testing.T) { + imgBuf := new(bytes.Buffer) + img := data.CreateMockImage(t, 300, 300, data.ImageSizeSmall) + err := png.Encode(imgBuf, img) + require.NoError(t, err) + + err = models.Organizations.Update(ctx, &data.OrganizationUpdate{Logo: imgBuf.Bytes()}) + require.NoError(t, err) + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.GetOrganizationLogo).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + org, err := models.Organizations.Get(ctx) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, org.Logo, respBody) + }) +} diff --git a/internal/serve/httphandler/receiver_handler.go b/internal/serve/httphandler/receiver_handler.go new file mode 100644 index 000000000..87d1e17a7 --- /dev/null +++ b/internal/serve/httphandler/receiver_handler.go @@ -0,0 +1,128 @@ +package httphandler + +import ( + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpresponse" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" +) + +type ReceiverHandler struct { + Models *data.Models + DBConnectionPool db.DBConnectionPool +} + +type GetReceiverResponse struct { + data.Receiver + Wallets []data.ReceiverWallet `json:"wallets"` +} + +func (rh ReceiverHandler) buildReceiversResponse(receivers []data.Receiver, receiversWallets []data.ReceiverWallet) []GetReceiverResponse { + var responses []GetReceiverResponse + + for _, receiver := range receivers { + wallets := make([]data.ReceiverWallet, 0) + for _, wallet := range receiversWallets { + if wallet.Receiver.ID == receiver.ID { + wallets = append(wallets, wallet) + } + } + responses = append(responses, GetReceiverResponse{ + Receiver: receiver, + Wallets: wallets, + }) + } + + return responses +} + +func (rh ReceiverHandler) GetReceiver(w http.ResponseWriter, r *http.Request) { + receiverID := chi.URLParam(r, "id") + ctx := r.Context() + + response, err := db.RunInTransactionWithResult(ctx, rh.DBConnectionPool, nil, func(dbTx db.DBTransaction) (response *GetReceiverResponse, innerErr error) { + receiver, innerErr := rh.Models.Receiver.Get(ctx, dbTx, receiverID) + if innerErr != nil { + return nil, fmt.Errorf("getting receiver by ID: %w", innerErr) + } + + receiverWallets, innerErr := rh.Models.ReceiverWallet.GetWithReceiverIds(ctx, dbTx, data.ReceiverIDs{receiver.ID}) + if innerErr != nil { + return nil, fmt.Errorf("getting receiver wallets with receiver IDs: %w", innerErr) + } + + return &GetReceiverResponse{ + Receiver: *receiver, + Wallets: receiverWallets, + }, nil + }) + if err != nil { + if errors.Is(err, data.ErrRecordNotFound) { + errorResponse := fmt.Sprintf("could not retrieve receiver with ID: %s", receiverID) + httperror.NotFound(errorResponse, err, nil).Render(w) + } else { + msg := fmt.Sprintf("Cannot retrieve receiver with ID %s", receiverID) + httperror.InternalError(ctx, msg, err, nil).Render(w) + } + return + } + + httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON) +} + +func (rh ReceiverHandler) GetReceivers(w http.ResponseWriter, r *http.Request) { + validator := validators.NewReceiverQueryValidator() + + queryParams := validator.ParseParametersFromRequest(r) + queryParams.Filters = validator.ValidateAndGetReceiverFilters(queryParams.Filters) + if validator.HasErrors() { + httperror.BadRequest("request invalid", nil, validator.Errors).Render(w) + return + } + + ctx := r.Context() + + httpResponse, err := db.RunInTransactionWithResult(ctx, rh.DBConnectionPool, nil, func(dbTx db.DBTransaction) (*httpresponse.PaginatedResponse, error) { + totalReceivers, err := rh.Models.Receiver.Count(ctx, dbTx, queryParams) + if err != nil { + return nil, fmt.Errorf("error retrieving receivers count: %w", err) + } + + if totalReceivers == 0 { + httpResponse := httpresponse.NewEmptyPaginatedResponse() + return &httpResponse, nil + } + + receivers, err := rh.Models.Receiver.GetAll(ctx, dbTx, queryParams) + if err != nil { + return nil, fmt.Errorf("error retrieving receivers: %w", err) + } + + receiverIDs := rh.Models.Receiver.ParseReceiverIDs(receivers) + receiversWallets, err := rh.Models.ReceiverWallet.GetWithReceiverIds(ctx, dbTx, receiverIDs) + if err != nil { + return nil, fmt.Errorf("error retrieving receiver wallets: %w", err) + } + + receiversResponse := rh.buildReceiversResponse(receivers, receiversWallets) + httpResponse, err := httpresponse.NewPaginatedResponse(r, receiversResponse, queryParams.Page, queryParams.PageLimit, totalReceivers) + if err != nil { + return nil, fmt.Errorf("error creating paginated response for receivers: %w", err) + } + + return &httpResponse, nil + }) + if err != nil { + httperror.InternalError(ctx, "Cannot retrieve receivers", err, nil).Render(w) + return + } + + httpjson.RenderStatus(w, http.StatusOK, httpResponse, httpjson.JSON) +} diff --git a/internal/serve/httphandler/receiver_handler_test.go b/internal/serve/httphandler/receiver_handler_test.go new file mode 100644 index 000000000..75ed52ba8 --- /dev/null +++ b/internal/serve/httphandler/receiver_handler_test.go @@ -0,0 +1,1543 @@ +package httphandler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReceiverHandlerGet(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &ReceiverHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + // setup + r := chi.NewRouter() + r.Get("/receivers/{id}", handler.GetReceiver) + + ctx := context.Background() + + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet1 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet1.com", "www.wallet1.com", "wallet1://") + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + + disbursement := data.Disbursement{ + Status: data.DraftDisbursementStatus, + Asset: asset, + Country: country, + } + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + payment := data.Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Asset: *asset, + } + + t.Run("successfully returns receiver details with receiver without wallet", func(t *testing.T) { + // test + route := fmt.Sprintf("/receivers/%s", receiver.ID) + req, err := http.NewRequest("GET", route, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + + wantJson := fmt.Sprintf(`{ + "id": %q, + "external_id": %q, + "email": %q, + "phone_number": %q, + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets": "0", + "wallets": [] + }`, receiver.ID, receiver.ExternalID, *receiver.Email, receiver.PhoneNumber, receiver.CreatedAt.Format(time.RFC3339Nano), receiver.UpdatedAt.Format(time.RFC3339Nano)) + + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + receiverWallet1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet1.ID, data.DraftReceiversWalletStatus) + + message1 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &receiverWallet1.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message2 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet1.ID, + ReceiverWalletID: &receiverWallet1.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + t.Run("successfully returns receiver details with one wallet for given ID", func(t *testing.T) { + disbursement.Name = "disbursement 1" + disbursement.Wallet = wallet1 + d := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &disbursement) + + payment.Status = data.SuccessPaymentStatus + payment.Disbursement = d + payment.ReceiverWallet = receiverWallet1 + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &payment) + + // test + route := fmt.Sprintf("/receivers/%s", receiver.ID) + req, err := http.NewRequest("GET", route, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + + wantJson := fmt.Sprintf(`{ + "id": %q, + "external_id": %q, + "email": %q, + "phone_number": %q, + "created_at": %q, + "updated_at": %q, + "total_payments": "1", + "successful_payments": "1", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ], + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet1.com", + "sep_10_client_domain": "www.wallet1.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ] + } + ] + }`, receiver.ID, receiver.ExternalID, *receiver.Email, receiver.PhoneNumber, receiver.CreatedAt.Format(time.RFC3339Nano), + receiver.UpdatedAt.Format(time.RFC3339Nano), receiverWallet1.ID, receiverWallet1.Receiver.ID, receiverWallet1.Wallet.ID, + receiverWallet1.StellarAddress, receiverWallet1.StellarMemo, receiverWallet1.StellarMemoType, + receiverWallet1.CreatedAt.Format(time.RFC3339Nano), receiverWallet1.UpdatedAt.Format(time.RFC3339Nano), + message1.CreatedAt.Format(time.RFC3339Nano), message2.CreatedAt.Format(time.RFC3339Nano)) + + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("successfully returns receiver details with multiple wallets for given ID", func(t *testing.T) { + wallet2 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet2", "https://www.wallet2.com", "www.wallet2.com", "wallet2://") + receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet2.ID, data.RegisteredReceiversWalletStatus) + + message3 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet2.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message4 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver.ID, + WalletID: wallet2.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + disbursement.Name = "disbursement 2" + disbursement.Wallet = wallet2 + d := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &disbursement) + + payment.Status = data.DraftPaymentStatus + payment.Disbursement = d + payment.ReceiverWallet = receiverWallet2 + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &payment) + + // test + route := fmt.Sprintf("/receivers/%s", receiver.ID) + req, err := http.NewRequest("GET", route, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + + wantJson := fmt.Sprintf(`{ + "id": %q, + "external_id": %q, + "email": %q, + "phone_number": %q, + "created_at": %q, + "updated_at": %q, + "total_payments": "2", + "successful_payments": "1", + "failed_payments": "0", + "remaining_payments": "1", + "registered_wallets": "1", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ], + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet1.com", + "sep_10_client_domain": "www.wallet1.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ] + }, + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet2", + "homepage": "https://www.wallet2.com", + "sep_10_client_domain": "www.wallet2.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "REGISTERED", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "1", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "0" + } + ] + } + ] + }`, receiver.ID, receiver.ExternalID, *receiver.Email, receiver.PhoneNumber, receiver.CreatedAt.Format(time.RFC3339Nano), + receiver.UpdatedAt.Format(time.RFC3339Nano), receiverWallet1.ID, receiverWallet1.Receiver.ID, + receiverWallet1.Wallet.ID, receiverWallet1.StellarAddress, receiverWallet1.StellarMemo, receiverWallet1.StellarMemoType, + receiverWallet1.CreatedAt.Format(time.RFC3339Nano), receiverWallet1.UpdatedAt.Format(time.RFC3339Nano), + message1.CreatedAt.Format(time.RFC3339Nano), message2.CreatedAt.Format(time.RFC3339Nano), + receiverWallet2.ID, receiverWallet2.Receiver.ID, receiverWallet2.Wallet.ID, + receiverWallet2.StellarAddress, receiverWallet2.StellarMemo, receiverWallet2.StellarMemoType, + receiverWallet2.CreatedAt.Format(time.RFC3339Nano), receiverWallet2.UpdatedAt.Format(time.RFC3339Nano), + message3.CreatedAt.Format(time.RFC3339Nano), message4.CreatedAt.Format(time.RFC3339Nano)) + + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("error receiver not found for given ID", func(t *testing.T) { + // test + req, err := http.NewRequest("GET", "/receivers/invalid_id", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusNotFound, rr.Code) + + wantJson := `{ + "error": "could not retrieve receiver with ID: invalid_id" + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) +} + +func Test_ReceiverHandler_GetReceivers_Errors(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &ReceiverHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ts := httptest.NewServer(http.HandlerFunc(handler.GetReceivers)) + defer ts.Close() + + tests := []struct { + name string + queryParams map[string]string + expectedStatusCode int + expectedResponse string + }{ + { + name: "returns error when sort parameter is invalid", + queryParams: map[string]string{ + "sort": "invalid_sort", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"sort":"invalid sort field name"}}`, + }, + { + name: "returns error when direction is invalid", + queryParams: map[string]string{ + "direction": "invalid_direction", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"direction":"invalid sort order. valid values are 'asc' and 'desc'"}}`, + }, + { + name: "returns error when page is invalid", + queryParams: map[string]string{ + "page": "invalid_page", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"page":"parameter must be an integer"}}`, + }, + { + name: "returns error when page_limit is invalid", + queryParams: map[string]string{ + "page_limit": "invalid_page_limit", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"page_limit":"parameter must be an integer"}}`, + }, + { + name: "returns error when status is invalid", + queryParams: map[string]string{ + "status": "invalid_status", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"status":"invalid parameter. valid values are: draft, ready, registered, flagged"}}`, + }, + { + name: "returns error when created_at_after is invalid", + queryParams: map[string]string{ + "created_at_after": "invalid_created_at_after", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"created_at_after":"invalid date format. valid format is 'YYYY-MM-DD'"}}`, + }, + { + name: "returns error when created_at_before is invalid", + queryParams: map[string]string{ + "created_at_before": "invalid_created_at_before", + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponse: `{"error":"request invalid", "extras":{"created_at_before":"invalid date format. valid format is 'YYYY-MM-DD'"}}`, + }, + { + name: "returns empty list when no expectedPayments are found", + queryParams: map[string]string{}, + expectedStatusCode: http.StatusOK, + expectedResponse: `{"data":[], "pagination":{"pages":0, "total": 0}}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build the URL for the test request + url := buildURLWithQueryParams(ts.URL, "/payments", tc.queryParams) + resp, err := http.Get(url) + require.NoError(t, err) + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tc.expectedStatusCode, resp.StatusCode) + assert.JSONEq(t, tc.expectedResponse, string(respBody)) + }) + } +} + +func Test_ReceiverHandler_GetReceivers_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + handler := &ReceiverHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ts := httptest.NewServer(http.HandlerFunc(handler.GetReceivers)) + defer ts.Close() + + ctx := context.Background() + + // create fixtures + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + // create receivers + date := time.Date(2022, 12, 10, 23, 40, 20, 1431, time.UTC) + receiver1Email := "receiver1@mock.com" + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + Email: &receiver1Email, + ExternalID: "external_id_1", + PhoneNumber: "+99991111", + CreatedAt: &date, + UpdatedAt: &date, + }) + + date = time.Date(2023, 1, 10, 23, 40, 20, 1431, time.UTC) + receiver2Email := "receiver2@mock.com" + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + Email: &receiver2Email, + ExternalID: "external_id_2", + PhoneNumber: "+99992222", + CreatedAt: &date, + UpdatedAt: &date, + }) + receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + message1 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver2.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message2 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver2.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + date = time.Date(2023, 2, 10, 23, 40, 21, 1431, time.UTC) + receiver3Email := "receiver3@mock.com" + receiver3 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + Email: &receiver3Email, + ExternalID: "external_id_3", + PhoneNumber: "+99993333", + CreatedAt: &date, + UpdatedAt: &date, + }) + receiverWallet3 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver3.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + message3 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver3.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet3.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message4 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver3.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet3.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + date = time.Date(2023, 3, 10, 23, 40, 20, 1431, time.UTC) + receiver4Email := "receiver4@mock.com" + receiver4 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + Email: &receiver4Email, + ExternalID: "external_id_4", + PhoneNumber: "+99994444", + CreatedAt: &date, + UpdatedAt: &date, + }) + receiverWallet4 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver4.ID, wallet.ID, data.DraftReceiversWalletStatus) + + message5 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver4.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet4.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message6 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver4.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet4.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + // create disbursements + disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + // create payments + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.SuccessPaymentStatus, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet2, + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.DraftPaymentStatus, + Disbursement: disbursement1, + Asset: *asset, + ReceiverWallet: receiverWallet3, + }) + + tests := []struct { + name string + queryParams map[string]string + expectedStatusCode int + expectedResponse string + }{ + { + name: "fetch all receivers without filters", + queryParams: map[string]string{}, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "pages": 1, + "total": 4 + }, + "data": [ + { + "id": %q, + "email": "receiver4@mock.com", + "external_id": "external_id_4", + "phone_number": "+99994444", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "0", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "0" + } + ] + }, + { + "id": %q, + "email": "receiver3@mock.com", + "external_id": "external_id_3", + "phone_number": "+99993333", + "created_at": %q, + "updated_at": %q, + "total_payments": "1", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "1", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "0" + } + ], + "registered_wallets":"1", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "REGISTERED", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "1", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "0" + } + ] + } + ] + }, + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external_id_2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "1", + "successful_payments": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ], + "registered_wallets":"1", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "REGISTERED", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ] + } + ] + }, + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external_id_1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [] + } + ] + }`, + receiver4.ID, receiver4.CreatedAt.Format(time.RFC3339Nano), receiver4.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet4.ID, receiverWallet4.Receiver.ID, receiverWallet4.Wallet.ID, + receiverWallet4.StellarAddress, receiverWallet4.StellarMemo, receiverWallet4.StellarMemoType, + receiverWallet4.CreatedAt.Format(time.RFC3339Nano), receiverWallet4.UpdatedAt.Format(time.RFC3339Nano), + message5.CreatedAt.Format(time.RFC3339Nano), message6.CreatedAt.Format(time.RFC3339Nano), + receiver3.ID, receiver3.CreatedAt.Format(time.RFC3339Nano), receiver3.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet3.ID, receiverWallet3.Receiver.ID, receiverWallet3.Wallet.ID, + receiverWallet3.StellarAddress, receiverWallet3.StellarMemo, receiverWallet3.StellarMemoType, + receiverWallet3.CreatedAt.Format(time.RFC3339Nano), receiverWallet3.UpdatedAt.Format(time.RFC3339Nano), + message3.CreatedAt.Format(time.RFC3339Nano), message4.CreatedAt.Format(time.RFC3339Nano), + receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet2.ID, receiverWallet2.Receiver.ID, receiverWallet2.Wallet.ID, receiverWallet2.StellarAddress, + receiverWallet2.StellarMemo, receiverWallet2.StellarMemoType, + receiverWallet2.CreatedAt.Format(time.RFC3339Nano), receiverWallet2.UpdatedAt.Format(time.RFC3339Nano), + message1.CreatedAt.Format(time.RFC3339Nano), message2.CreatedAt.Format(time.RFC3339Nano), + receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch first page of receivers with limit 1 and sort by created_at", + queryParams: map[string]string{ + "page": "1", + "page_limit": "1", + "sort": "created_at", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "next": "/receivers?direction=asc\u0026page=2\u0026page_limit=1\u0026sort=created_at", + "pages": 4, + "total": 4 + }, + "data": [ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external_id_1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [] + } + ] + }`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch second page of receivers with limit 1 and sort by created_at", + queryParams: map[string]string{ + "page": "2", + "page_limit": "1", + "sort": "created_at", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "prev": "/receivers?direction=asc\u0026page=1\u0026page_limit=1\u0026sort=created_at", + "next": "/receivers?direction=asc\u0026page=3\u0026page_limit=1\u0026sort=created_at", + "pages": 4, + "total": 4 + }, + "data": [ + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external_id_2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "1", + "successful_payments": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ], + "registered_wallets":"1", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "REGISTERED", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ] + } + ] + } + ] + }`, receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet2.ID, receiverWallet2.Receiver.ID, receiverWallet2.Wallet.ID, + receiverWallet2.StellarAddress, receiverWallet2.StellarMemo, receiverWallet2.StellarMemoType, + receiverWallet2.CreatedAt.Format(time.RFC3339Nano), receiverWallet2.UpdatedAt.Format(time.RFC3339Nano), + message1.CreatedAt.Format(time.RFC3339Nano), message2.CreatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch last page of receivers with limit 1 and sort by created_at", + queryParams: map[string]string{ + "page": "4", + "page_limit": "1", + "sort": "created_at", + "direction": "asc", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "prev": "/receivers?direction=asc\u0026page=3\u0026page_limit=1\u0026sort=created_at", + "pages": 4, + "total": 4 + }, + "data": [ + { + "id": %q, + "email": "receiver4@mock.com", + "external_id": "external_id_4", + "phone_number": "+99994444", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "0", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "0" + } + ] + } + ] + }`, receiver4.ID, receiver4.CreatedAt.Format(time.RFC3339Nano), receiver4.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet4.ID, receiverWallet4.Receiver.ID, receiverWallet4.Wallet.ID, + receiverWallet4.StellarAddress, receiverWallet4.StellarMemo, receiverWallet4.StellarMemoType, + receiverWallet4.CreatedAt.Format(time.RFC3339Nano), receiverWallet4.UpdatedAt.Format(time.RFC3339Nano), + message5.CreatedAt.Format(time.RFC3339Nano), message6.CreatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch receivers with status draft", + queryParams: map[string]string{ + "status": "dRaFt", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "pages": 1, + "total": 1 + }, + "data": [ + { + "id": %q, + "email": "receiver4@mock.com", + "external_id": "external_id_4", + "phone_number": "+99994444", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "0", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "0" + } + ] + } + ] + }`, receiver4.ID, receiver4.CreatedAt.Format(time.RFC3339Nano), receiver4.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet4.ID, receiverWallet4.Receiver.ID, receiverWallet4.Wallet.ID, + receiverWallet4.StellarAddress, receiverWallet4.StellarMemo, receiverWallet4.StellarMemoType, + receiverWallet4.CreatedAt.Format(time.RFC3339Nano), receiverWallet4.UpdatedAt.Format(time.RFC3339Nano), + message5.CreatedAt.Format(time.RFC3339Nano), message6.CreatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch receivers created before 2023-01-01", + queryParams: map[string]string{ + "created_at_before": "2023-01-01", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "pages": 1, + "total": 1 + }, + "data": [ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external_id_1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [] + } + ] + }`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch receivers created after 2023-03-01", + queryParams: map[string]string{ + "created_at_after": "2023-03-01", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "pages": 1, + "total": 1 + }, + "data": [ + { + "id": %q, + "email": "receiver4@mock.com", + "external_id": "external_id_4", + "phone_number": "+99994444", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "0", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "0" + } + ] + } + ] + }`, receiver4.ID, receiver4.CreatedAt.Format(time.RFC3339Nano), receiver4.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet4.ID, receiverWallet4.Receiver.ID, receiverWallet4.Wallet.ID, + receiverWallet4.StellarAddress, receiverWallet4.StellarMemo, receiverWallet4.StellarMemoType, + receiverWallet4.CreatedAt.Format(time.RFC3339Nano), receiverWallet4.UpdatedAt.Format(time.RFC3339Nano), + message5.CreatedAt.Format(time.RFC3339Nano), message6.CreatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch receivers created after 2023-01-01 and before 2023-03-01", + queryParams: map[string]string{ + "created_at_after": "2023-01-01", + "created_at_before": "2023-03-01", + "sort": "created_at", + "direction": "desc", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "pages": 1, + "total": 2 + }, + "data": [ + { + "id": %q, + "email": "receiver3@mock.com", + "external_id": "external_id_3", + "phone_number": "+99993333", + "created_at": %q, + "updated_at": %q, + "total_payments": "1", + "successful_payments": "0", + "received_amounts": "0", + "failed_payments": "0", + "remaining_payments": "1", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "0" + } + ], + "registered_wallets":"1", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "REGISTERED", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "1", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "0" + } + ] + } + ] + }, + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external_id_2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "1", + "successful_payments": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ], + "registered_wallets":"1", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "REGISTERED", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "1", + "payments_received": "1", + "failed_payments": "0", + "remaining_payments": "0", + "received_amounts": [ + { + "asset_code": "USDC", + "asset_issuer": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV", + "received_amount": "50.0000000" + } + ] + } + ] + } + ] + }`, receiver3.ID, receiver3.CreatedAt.Format(time.RFC3339Nano), receiver3.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet3.ID, receiverWallet3.Receiver.ID, receiverWallet3.Wallet.ID, + receiverWallet3.StellarAddress, receiverWallet3.StellarMemo, receiverWallet3.StellarMemoType, + receiverWallet3.CreatedAt.Format(time.RFC3339Nano), receiverWallet3.UpdatedAt.Format(time.RFC3339Nano), + message3.CreatedAt.Format(time.RFC3339Nano), message4.CreatedAt.Format(time.RFC3339Nano), + receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet2.ID, receiverWallet2.Receiver.ID, receiverWallet2.Wallet.ID, + receiverWallet2.StellarAddress, receiverWallet2.StellarMemo, receiverWallet2.StellarMemoType, + receiverWallet2.CreatedAt.Format(time.RFC3339Nano), receiverWallet2.UpdatedAt.Format(time.RFC3339Nano), + message1.CreatedAt.Format(time.RFC3339Nano), message2.CreatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch receivers with email = receiver1@mock.com", + queryParams: map[string]string{ + "q": receiver1Email, + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "pages": 1, + "total": 1 + }, + "data": [ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external_id_1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [] + } + ] + }`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)), + }, + { + name: "fetch receivers with phone_number = +99991111", + queryParams: map[string]string{ + "q": "+99991111", + }, + expectedStatusCode: http.StatusOK, + expectedResponse: fmt.Sprintf(`{ + "pagination": { + "pages": 1, + "total": 1 + }, + "data": [ + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external_id_1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [] + } + ] + }`, receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano)), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Build the URL for the test request + url := buildURLWithQueryParams(ts.URL, "/receivers", tc.queryParams) + resp, err := http.Get(url) + require.NoError(t, err) + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tc.expectedStatusCode, resp.StatusCode) + + assert.JSONEq(t, tc.expectedResponse, string(respBody)) + }) + } +} + +func Test_ReceiverHandler_BuildReceiversResponse(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &ReceiverHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ctx := context.Background() + + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver1Email := "receiver1@mock.com" + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + Email: &receiver1Email, + ExternalID: "external_id_1", + PhoneNumber: "+99991111", + }) + receiver2Email := "receiver2@mock.com" + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + Email: &receiver2Email, + ExternalID: "external_id_2", + PhoneNumber: "+99992222", + }) + + receiverWallet1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.DraftReceiversWalletStatus) + receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + message1 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver1.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet1.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message2 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver1.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet1.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + message3 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver2.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 1, 10, 23, 40, 20, 1000, time.UTC), + }) + + message4 := data.CreateMessageFixture(t, ctx, dbConnectionPool, &data.Message{ + Type: message.MessengerTypeTwilioSMS, + AssetID: nil, + ReceiverID: receiver2.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWallet2.ID, + Status: data.SuccessMessageStatus, + CreatedAt: time.Date(2023, 2, 10, 23, 40, 20, 1000, time.UTC), + }) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + + receivers, err := handler.Models.Receiver.GetAll(ctx, dbTx, &data.QueryParams{SortBy: data.SortFieldUpdatedAt, SortOrder: data.SortOrderDESC}) + require.NoError(t, err) + receiversId := handler.Models.Receiver.ParseReceiverIDs(receivers) + receiversWallets, err := handler.Models.ReceiverWallet.GetWithReceiverIds(ctx, dbTx, receiversId) + require.NoError(t, err) + + actualResponse := handler.buildReceiversResponse(receivers, receiversWallets) + + ar, err := json.Marshal(actualResponse) + require.NoError(t, err) + + wantJson := fmt.Sprintf(`[ + { + "id": %q, + "email": "receiver2@mock.com", + "external_id": "external_id_2", + "phone_number": "+99992222", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "READY", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "0", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "0" + } + ] + }, + { + "id": %q, + "email": "receiver1@mock.com", + "external_id": "external_id_1", + "phone_number": "+99991111", + "created_at": %q, + "updated_at": %q, + "total_payments": "0", + "successful_payments": "0", + "failed_payments": "0", + "remaining_payments": "0", + "registered_wallets":"0", + "wallets": [ + { + "id": %q, + "receiver": { + "id": %q + }, + "wallet": { + "id": %q, + "name": "wallet1", + "homepage": "https://www.wallet.com", + "sep_10_client_domain": "www.wallet.com" + }, + "stellar_address": %q, + "stellar_memo": %q, + "stellar_memo_type": %q, + "status": "DRAFT", + "created_at": %q, + "updated_at": %q, + "invited_at": %q, + "last_sms_sent": %q, + "total_payments": "0", + "payments_received": "0", + "failed_payments": "0", + "remaining_payments": "0" + } + ] + } + ]`, receiver2.ID, receiver2.CreatedAt.Format(time.RFC3339Nano), receiver2.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet2.ID, receiverWallet2.Receiver.ID, receiverWallet2.Wallet.ID, + receiverWallet2.StellarAddress, receiverWallet2.StellarMemo, receiverWallet2.StellarMemoType, + receiverWallet2.CreatedAt.Format(time.RFC3339Nano), receiverWallet2.UpdatedAt.Format(time.RFC3339Nano), + message3.CreatedAt.Format(time.RFC3339Nano), message4.CreatedAt.Format(time.RFC3339Nano), + receiver1.ID, receiver1.CreatedAt.Format(time.RFC3339Nano), receiver1.UpdatedAt.Format(time.RFC3339Nano), + receiverWallet1.ID, receiverWallet1.Receiver.ID, receiverWallet1.Wallet.ID, + receiverWallet1.StellarAddress, receiverWallet1.StellarMemo, receiverWallet1.StellarMemoType, + receiverWallet1.CreatedAt.Format(time.RFC3339Nano), receiverWallet1.UpdatedAt.Format(time.RFC3339Nano), + message1.CreatedAt.Format(time.RFC3339Nano), message2.CreatedAt.Format(time.RFC3339Nano)) + + assert.JSONEq(t, wantJson, string(ar)) + + err = dbTx.Commit() + require.NoError(t, err) +} diff --git a/internal/serve/httphandler/receiver_registration.go b/internal/serve/httphandler/receiver_registration.go new file mode 100644 index 000000000..79700bfa6 --- /dev/null +++ b/internal/serve/httphandler/receiver_registration.go @@ -0,0 +1,88 @@ +package httphandler + +import ( + "errors" + "fmt" + "net/http" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + htmlTpl "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" +) + +type ReceiverRegistrationHandler struct { + ReceiverWalletModel *data.ReceiverWalletModel + ReCAPTCHASiteKey string +} + +type ReceiverRegistrationData struct { + StellarAccount string + JWTToken string + Title string + Message string + ReCAPTCHASiteKey string +} + +// ServeHTTP will serve the SEP-24 deposit page needed to register users. +func (h ReceiverRegistrationHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + sep24Claims := anchorplatform.GetSEP24Claims(ctx) + if sep24Claims == nil { + err := fmt.Errorf("no SEP-24 claims found in the request context") + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(w) + return + } + + token := r.URL.Query().Get("token") + if token == "" { + err := fmt.Errorf("no token was provided in the request") + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(w) + return + } + + err := sep24Claims.Valid() + if err != nil { + err = fmt.Errorf("SEP-24 claims are invalid: %w", err) + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(w) + return + } + + rw, err := h.ReceiverWalletModel.GetByStellarAccountAndMemo(ctx, sep24Claims.SEP10StellarAccount(), sep24Claims.SEP10StellarMemo()) + if err != nil && !errors.Is(err, data.ErrRecordNotFound) { + httperror.InternalError(ctx, "Cannot register receiver wallet", err, nil).Render(w) + return + } + + tmplData := ReceiverRegistrationData{ + StellarAccount: sep24Claims.SEP10StellarAccount(), + JWTToken: token, + ReCAPTCHASiteKey: h.ReCAPTCHASiteKey, + } + + htmlTemplateName := "receiver_register.tmpl" + if rw != nil { + // If the user was previously registered successfully, load a different template. + htmlTemplateName = "receiver_registered_successfully.tmpl" + tmplData.Title = "Registration Complete πŸŽ‰" + tmplData.Message = "Your Stellar wallet has been registered successfully!" + } + + registerPage, err := htmlTpl.ExecuteHTMLTemplate(htmlTemplateName, tmplData) + if err != nil { + httperror.InternalError(ctx, "Cannot process the html template for request", err, nil).Render(w) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, err = w.Write([]byte(registerPage)) + if err != nil { + httperror.InternalError(ctx, "Cannot write html content to response", err, nil).Render(w) + return + } +} diff --git a/internal/serve/httphandler/receiver_registration_test.go b/internal/serve/httphandler/receiver_registration_test.go new file mode 100644 index 000000000..ccf05521c --- /dev/null +++ b/internal/serve/httphandler/receiver_registration_test.go @@ -0,0 +1,135 @@ +package httphandler + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/golang-jwt/jwt/v4" + "github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReceiverRegistrationHandler_ServeHTTP(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + receiverWalletModel := models.ReceiverWallet + reCAPTCHASiteKey := "reCAPTCHASiteKey" + + r := chi.NewRouter() + r.Get("/receiver-registration/start", ReceiverRegistrationHandler{ReceiverWalletModel: receiverWalletModel, ReCAPTCHASiteKey: reCAPTCHASiteKey}.ServeHTTP) + + t.Run("returns 401 - Unauthorized if the token is not in the request context", func(t *testing.T) { + req, err := http.NewRequest("GET", "/receiver-registration/start", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns 401 - Unauthorized if the token is in the request context but it's not valid", func(t *testing.T) { + req, err := http.NewRequest("GET", "/receiver-registration/start", nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + invalidClaims := &anchorplatform.SEP24JWTClaims{} + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, invalidClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns 200 - Ok (And show the Wallet Registration page) if the token is in the request context and it's valid πŸŽ‰", func(t *testing.T) { + req, err := http.NewRequest("GET", "/receiver-registration/start?token=test-token", nil) + require.NoError(t, err) + + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: "test.com", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) + assert.Contains(t, string(respBody), "Wallet Registration") + assert.Contains(t, string(respBody), `
`) + assert.Contains(t, string(respBody), ``) + }) + + t.Run("returns 200 - Ok (And show the Registration Success page) if the token is in the request context and it's valid and the user was already registered πŸŽ‰", func(t *testing.T) { + req, err := http.NewRequest("GET", "/receiver-registration/start?token=test-token", nil) + require.NoError(t, err) + + ctx := context.Background() + + // Create a receiver wallet + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, + "My Wallet", + "https://mywallet.com", + "mywallet.com", + "mywallet://") + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + receiverWallet.StellarAddress = "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444" + receiverWallet.StellarMemo = "" + err = receiverWalletModel.UpdateReceiverWallet(ctx, *receiverWallet, dbConnectionPool) + require.NoError(t, err) + + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: "mywallet.com", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) + assert.Contains(t, string(respBody), "Wallet Registration Confirmation") + }) +} diff --git a/internal/serve/httphandler/receiver_send_otp_handler.go b/internal/serve/httphandler/receiver_send_otp_handler.go new file mode 100644 index 000000000..f57745f09 --- /dev/null +++ b/internal/serve/httphandler/receiver_send_otp_handler.go @@ -0,0 +1,133 @@ +package httphandler + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + htmlTpl "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type ReceiverSendOTPHandler struct { + Models *data.Models + SMSMessengerClient message.MessengerClient + ReCAPTCHAValidator validators.ReCAPTCHAValidator +} + +type ReceiverSendOTPData struct { + OTP string +} + +type ReceiverSendOTPRequest struct { + PhoneNumber string `json:"phone_number"` + ReCAPTCHAToken string `json:"recaptcha_token"` +} + +type ReceiverSendOTPResponseBody struct { + Message string `json:"message"` +} + +func (h ReceiverSendOTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + receiverSendOTPRequest := ReceiverSendOTPRequest{} + + err := json.NewDecoder(r.Body).Decode(&receiverSendOTPRequest) + if err != nil { + httperror.BadRequest("invalid request body", err, nil).Render(w) + return + } + + // validating reCAPTCHA Token + isValid, err := h.ReCAPTCHAValidator.IsTokenValid(ctx, receiverSendOTPRequest.ReCAPTCHAToken) + if err != nil { + httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", err, nil).Render(w) + return + } + + if !isValid { + log.Ctx(ctx).Errorf("reCAPTCHA token is invalid") + httperror.BadRequest("reCAPTCHA token is invalid", nil, nil).Render(w) + return + } + + // validate request + v := validators.NewValidator() + + v.Check(receiverSendOTPRequest.PhoneNumber != "", "phone_number", "phone_number is required") + + if v.HasErrors() { + httperror.BadRequest("request invalid", err, v.Errors).Render(w) + return + } + + // Get clains from SEP24 JWT + sep24Claims := anchorplatform.GetSEP24Claims(ctx) + if sep24Claims == nil { + err = fmt.Errorf("no SEP-24 claims found in the request context") + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(w) + return + } + + err = sep24Claims.Valid() + if err != nil { + err = fmt.Errorf("SEP-24 claims are invalid: %w", err) + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(w) + return + } + + // Generate a new 6 digits OTP + newOTP, err := utils.RandomString(6, utils.NumberBytes) + if err != nil { + httperror.InternalError(ctx, "Cannot generate OTP for receiver wallet", err, nil).Render(w) + return + } + + numberOfUpdatedRows, err := h.Models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, receiverSendOTPRequest.PhoneNumber, sep24Claims.ClientDomainClaim, newOTP) + if err != nil { + httperror.InternalError(ctx, "Cannot update OTP for receiver wallet", err, nil).Render(w) + return + } + + if numberOfUpdatedRows < 1 { + log.Ctx(ctx).Warnf("updated no rows in receiver send OTP handler for phone number: %s", utils.TruncateString(receiverSendOTPRequest.PhoneNumber, len(receiverSendOTPRequest.PhoneNumber)/4)) + } else { + // Build the data object that will be injected in message template + sendOTPData := ReceiverSendOTPData{ + OTP: newOTP, + } + + sendOTPMessage, err := htmlTpl.ExecuteHTMLTemplate("receiver_otp_message.tmpl", sendOTPData) + if err != nil { + httperror.InternalError(ctx, "Cannot execute OTP template", err, nil).Render(w) + return + } + + smsMessage := message.Message{ + ToPhoneNumber: receiverSendOTPRequest.PhoneNumber, + Message: sendOTPMessage, + } + + log.Ctx(ctx).Infof("sending OTP message to phone number: %s", utils.TruncateString(receiverSendOTPRequest.PhoneNumber, 3)) + err = h.SMSMessengerClient.SendMessage(smsMessage) + if err != nil { + httperror.InternalError(ctx, "Cannot send OTP message", err, nil).Render(w) + return + } + } + + response := ReceiverSendOTPResponseBody{ + Message: "if your phone number is registered, you'll receive an OTP", + } + httpjson.RenderStatus(w, http.StatusOK, response, httpjson.JSON) +} diff --git a/internal/serve/httphandler/receiver_send_otp_handler_test.go b/internal/serve/httphandler/receiver_send_otp_handler_test.go new file mode 100644 index 000000000..c81ec6d6d --- /dev/null +++ b/internal/serve/httphandler/receiver_send_otp_handler_test.go @@ -0,0 +1,280 @@ +package httphandler + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/golang-jwt/jwt/v4" + "github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockMessengerClient struct { + mock.Mock +} + +func (m *mockMessengerClient) SendMessage(message message.Message) error { + return m.Called(message).Error(0) +} + +func (mc *mockMessengerClient) MessengerType() message.MessengerType { + args := mc.Called() + return args.Get(0).(message.MessengerType) +} + +func Test_ReceiverSendOTPHandler_ServeHTTP(t *testing.T) { + r := chi.NewRouter() + + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + wallet1 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "https://home.page", "home.page", "wallet123://") + + _ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, data.RegisteredReceiversWalletStatus) + _ = data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet1.ID, data.RegisteredReceiversWalletStatus) + + mockMessenger := mockMessengerClient{} + reCAPTCHAValidator := &validators.ReCAPTCHAValidatorMock{} + + r.Post("/wallet-registration/otp", ReceiverSendOTPHandler{ + Models: models, + SMSMessengerClient: &mockMessenger, + ReCAPTCHAValidator: reCAPTCHAValidator, + }.ServeHTTP) + + requestSendOTP := ReceiverSendOTPRequest{ + PhoneNumber: receiver1.PhoneNumber, + ReCAPTCHAToken: "XyZ", + } + reqBody, err := json.Marshal(requestSendOTP) + require.NoError(t, err) + + t.Run("returns 401 - Unauthorized if the token is not in the request context", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + req, err := http.NewRequest("POST", "/wallet-registration/otp", strings.NewReader(string(reqBody))) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns 401 - Unauthorized if the token is in the request context but it's not valid", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + req, err := http.NewRequest("POST", "/wallet-registration/otp", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + invalidClaims := &anchorplatform.SEP24JWTClaims{} + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, invalidClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns 400 - BadRequest with a wrong request body", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + invalidRequest := `{"recaptcha_token": "XyZ"}` + + req, err := http.NewRequest("POST", "/wallet-registration/otp", strings.NewReader(invalidRequest)) + require.NoError(t, err) + + rr := httptest.NewRecorder() + invalidClaims := &anchorplatform.SEP24JWTClaims{} + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, invalidClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error":"request invalid","extras":{"phone_number":"phone_number is required"}}`, string(respBody)) + }) + + t.Run("returns 200 - Ok if the token is in the request context and body it's valid", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + req, err := http.NewRequest("POST", "/wallet-registration/otp", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet1.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + mockMessenger.On("SendMessage", mock.AnythingOfType("message.Message")). + Return(nil). + Once() + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, resp.Header.Get("Content-Type"), "/json; charset=utf-8") + assert.JSONEq(t, string(respBody), `{"message":"if your phone number is registered, you'll receive an OTP"}`) + }) + + t.Run("returns 500 - InternalServerError when something goes wrong when sending the SMS", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(true, nil). + Once() + req, err := http.NewRequest("POST", "/wallet-registration/otp", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet1.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + mockMessenger.On("SendMessage", mock.AnythingOfType("message.Message")). + Return(errors.New("error sending message")). + Once() + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.Contains(t, resp.Header.Get("Content-Type"), "/json; charset=utf-8") + assert.JSONEq(t, string(respBody), `{"error":"Cannot send OTP message"}`) + }) + + t.Run("returns 500 - InternalServerError when unable to validate recaptcha", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(false, errors.New("error requesting verify reCAPTCHA token")). + Once() + + req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet1.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Cannot validate reCAPTCHA token" + } + ` + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("returns 400 - BadRequest when recaptcha token is invalid", func(t *testing.T) { + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "XyZ"). + Return(false, nil). + Once() + + req, err := http.NewRequest(http.MethodPost, "/wallet-registration/otp", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet1.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "reCAPTCHA token is invalid" + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + mockMessenger.AssertExpectations(t) + reCAPTCHAValidator.AssertExpectations(t) +} diff --git a/internal/serve/httphandler/refresh_token_handler.go b/internal/serve/httphandler/refresh_token_handler.go new file mode 100644 index 000000000..e20f9c9b8 --- /dev/null +++ b/internal/serve/httphandler/refresh_token_handler.go @@ -0,0 +1,38 @@ +package httphandler + +import ( + "errors" + "net/http" + + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +type RefreshTokenHandler struct { + AuthManager auth.AuthManager +} + +func (h RefreshTokenHandler) PostRefreshToken(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + refreshedToken, err := h.AuthManager.RefreshToken(ctx, token) + if err != nil { + if errors.Is(err, auth.ErrInvalidToken) { + httperror.BadRequest("", err, map[string]interface{}{"token": "token is invalid"}).Render(rw) + return + } + + httperror.InternalError(ctx, "Cannot refresh user token", err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]string{"token": refreshedToken}, httpjson.JSON) +} diff --git a/internal/serve/httphandler/refresh_token_handler_test.go b/internal/serve/httphandler/refresh_token_handler_test.go new file mode 100644 index 000000000..96524edba --- /dev/null +++ b/internal/serve/httphandler/refresh_token_handler_test.go @@ -0,0 +1,118 @@ +package httphandler + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_RefreshTokenHandler(t *testing.T) { + jwtManagerMock := &auth.JWTManagerMock{} + authManager := auth.NewAuthManager( + auth.WithCustomJWTManagerOption(jwtManagerMock), + ) + + handler := &RefreshTokenHandler{AuthManager: authManager} + url := "/refresh-token" + + ctx := context.Background() + + t.Run("returns Unauthorized error when no token is found", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + require.NoError(t, err) + + http.HandlerFunc(handler.PostRefreshToken).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + }) + + t.Run("returns BadRequest when token is expired", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), "mytoken"). + Return(false, nil). + Once() + + http.HandlerFunc(handler.PostRefreshToken).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"token": "token is invalid"}}`, string(respBody)) + }) + + t.Run("returns InternalServerError when AuthManager fails", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), "mytoken"). + Return(false, errors.New("unexpected error")). + Once() + + http.HandlerFunc(handler.PostRefreshToken).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot refresh user token"}`, string(respBody)) + }) + + t.Run("returns the refreshed token", func(t *testing.T) { + ctx = context.WithValue(ctx, middleware.TokenContextKey, "mytoken") + + w := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), "mytoken"). + Return(true, nil). + Once(). + On("RefreshToken", req.Context(), "mytoken", mock.AnythingOfType("time.Time")). + Return("myrefreshedtoken", nil). + Once() + + http.HandlerFunc(handler.PostRefreshToken).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"token": "myrefreshedtoken"}`, string(respBody)) + }) + + jwtManagerMock.AssertExpectations(t) +} diff --git a/internal/serve/httphandler/reset_password_handler.go b/internal/serve/httphandler/reset_password_handler.go new file mode 100644 index 000000000..663ee9c75 --- /dev/null +++ b/internal/serve/httphandler/reset_password_handler.go @@ -0,0 +1,60 @@ +package httphandler + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +// ResetPasswordHandler resets the user password by receiving a valid reset token +// and the new password. +type ResetPasswordHandler struct { + AuthManager auth.AuthManager +} + +type ResetPasswordRequest struct { + Password string `json:"password"` + ResetToken string `json:"reset_token"` +} + +// ServeHTTP implements the http.Handler interface. +func (h ResetPasswordHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var resetPasswordRequest ResetPasswordRequest + + err := json.NewDecoder(r.Body).Decode(&resetPasswordRequest) + if err != nil { + httperror.BadRequest("invalid request body", err, nil).Render(w) + return + } + + // validate request + v := validators.NewValidator() + + v.Check(resetPasswordRequest.Password != "", "password", "password is required") + v.Check(resetPasswordRequest.Password != "", "reset_token", "reset token is required") + + if v.HasErrors() { + httperror.BadRequest("request invalid", err, v.Errors).Render(w) + return + } + + ctx := r.Context() + + // Reset password email with a valid token + err = h.AuthManager.ResetPassword(ctx, resetPasswordRequest.ResetToken, resetPasswordRequest.Password) + if err != nil { + if errors.Is(err, auth.ErrInvalidResetPasswordToken) { + httperror.BadRequest("invalid reset password token", err, nil).Render(w) + return + } + httperror.InternalError(ctx, "Cannot reset password", err, v.Errors).Render(w) + return + } + + httpjson.RenderStatus(w, http.StatusOK, nil, httpjson.JSON) +} diff --git a/internal/serve/httphandler/reset_password_handler_test.go b/internal/serve/httphandler/reset_password_handler_test.go new file mode 100644 index 000000000..58d6cf880 --- /dev/null +++ b/internal/serve/httphandler/reset_password_handler_test.go @@ -0,0 +1,97 @@ +package httphandler + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ResetPasswordHandlerPost(t *testing.T) { + const url = "/reset-password" + const method = "POST" + + authenticatorMock := &auth.AuthenticatorMock{} + authManager := auth.NewAuthManager( + auth.WithCustomAuthenticatorOption(authenticatorMock), + ) + + handler := &ResetPasswordHandler{ + AuthManager: authManager, + } + + t.Run("Should return http status 200 on a valid request", func(t *testing.T) { + requestBody := `{ "password": "password123", "reset_token": "goodtoken" }` + + rr := httptest.NewRecorder() + req, _ := http.NewRequest(method, url, strings.NewReader(requestBody)) + + authenticatorMock. + On("ResetPassword", req.Context(), "goodtoken", "password123"). + Return(nil). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Should return an error with an invalid token", func(t *testing.T) { + requestBody := `{"password":"password123","reset_token":"badtoken"}` + + rr := httptest.NewRecorder() + req, _ := http.NewRequest(method, url, strings.NewReader(requestBody)) + + authenticatorMock. + On("ResetPassword", req.Context(), "badtoken", "password123"). + Return(auth.ErrInvalidResetPasswordToken). + Once() + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedBody := ` + { + "error": "invalid reset password token" + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, expectedBody, string(respBody)) + }) + + t.Run("Should require both password and reset_token params", func(t *testing.T) { + requestBody := `{"password":""}` + + rr := httptest.NewRecorder() + req, _ := http.NewRequest(method, url, strings.NewReader(requestBody)) + + http.HandlerFunc(handler.ServeHTTP).ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expectedBody := ` + { + "error":"request invalid", + "extras": { + "password":"password is required", + "reset_token":"reset token is required" + } + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, expectedBody, string(respBody)) + }) + + authenticatorMock.AssertExpectations(t) +} diff --git a/internal/serve/httphandler/statistics_handler.go b/internal/serve/httphandler/statistics_handler.go new file mode 100644 index 000000000..3ebf6c8bb --- /dev/null +++ b/internal/serve/httphandler/statistics_handler.go @@ -0,0 +1,48 @@ +package httphandler + +import ( + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/statistics" +) + +type StatisticsHandler struct { + DBConnectionPool db.DBConnectionPool +} + +func (s StatisticsHandler) GetStatistics(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + stats, err := statistics.CalculateStatistics(ctx, s.DBConnectionPool) + if err != nil { + httperror.InternalError(ctx, "Cannot calculate statistics", err, nil).Render(w) + return + } + + httpjson.RenderStatus(w, http.StatusOK, stats, httpjson.JSON) +} + +func (s StatisticsHandler) GetStatisticsByDisbursement(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + disbursementID := chi.URLParam(r, "id") + + stats, err := statistics.CalculateStatisticsByDisbursement(ctx, s.DBConnectionPool, disbursementID) + if err != nil { + if errors.Is(statistics.ErrResourcesNotFound, err) { + errorMsg := fmt.Sprintf("a disbursement with the id %s does not exist", disbursementID) + httperror.NotFound(errorMsg, err, nil).Render(w) + return + } else { + httperror.InternalError(ctx, "Cannot calculate statistics", err, nil).Render(w) + return + } + } + + httpjson.RenderStatus(w, http.StatusOK, stats, httpjson.JSON) +} diff --git a/internal/serve/httphandler/statistics_handler_test.go b/internal/serve/httphandler/statistics_handler_test.go new file mode 100644 index 000000000..597d3792e --- /dev/null +++ b/internal/serve/httphandler/statistics_handler_test.go @@ -0,0 +1,245 @@ +package httphandler + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStatisticsHandler(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + // setup + statisticsHandler := StatisticsHandler{DBConnectionPool: dbConnectionPool} + r := chi.NewRouter() + r.Get("/statistics", statisticsHandler.GetStatistics) + r.Get("/statistics/{id}", statisticsHandler.GetStatisticsByDisbursement) + + t.Run("get statistics with no data", func(t *testing.T) { + // test + var req *http.Request + req, err = http.NewRequest("GET", "/statistics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + + wantJson := `{ + "payment_counters": { + "draft": 0, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 0, + "failed": 0, + "total": 0 + }, + "payment_amounts_by_asset": [], + "receiver_wallets_counters": { + "draft": 0, + "ready": 0, + "registered": 0, + "flagged": 0, + "total": 0 + }, + "total_receivers": 0, + "total_disbursements": 0 + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("get statistics for invalid disbursement id", func(t *testing.T) { + // test + var req *http.Request + req, err = http.NewRequest("GET", "/statistics/invalid-id", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusNotFound, rr.Code) + + wantJson := `{ + "error": "a disbursement with the id invalid-id does not exist" + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + ctx := context.Background() + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + asset1 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.CompletedDisbursementStatus, + Asset: asset1, + Wallet: wallet, + Country: country, + }) + + t.Run("get statistics for existing disbursement with no data", func(t *testing.T) { + // test + var req *http.Request + req, err = http.NewRequest("GET", "/statistics/"+disbursement.ID, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + + wantJson := `{ + "payment_counters": { + "draft": 0, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 0, + "failed": 0, + "total": 0 + }, + "payment_amounts_by_asset": [], + "receiver_wallets_counters": { + "draft": 0, + "ready": 0, + "registered": 0, + "flagged": 0, + "total": 0 + }, + "total_receivers": 0 + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "10", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.DraftPaymentStatus, + Disbursement: disbursement, + Asset: *asset1, + ReceiverWallet: receiverWallet, + }) + + t.Run("get statistics", func(t *testing.T) { + // test + req, err := http.NewRequest("GET", "/statistics", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + wantJson := `{ + "payment_counters": { + "draft": 1, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 0, + "failed": 0, + "total": 1 + }, + "payment_amounts_by_asset": [ + { + "asset_code": "USDC", + "payment_amounts": { + "draft": "10.0000000", + "ready": "", + "pending": "", + "paused": "", + "success": "", + "failed": "", + "average": "10.0000000", + "total": "10.0000000" + } + } + ], + "receiver_wallets_counters": { + "draft": 1, + "ready": 0, + "registered": 0, + "flagged": 0, + "total": 1 + }, + "total_receivers": 1, + "total_disbursements": 1 + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) + + t.Run("get statistics for specific disbursement", func(t *testing.T) { + route := fmt.Sprintf("/statistics/%s", disbursement.ID) + req, err := http.NewRequest("GET", route, nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + + wantJson := `{ + "payment_counters": { + "draft": 1, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 0, + "failed": 0, + "total": 1 + }, + "payment_amounts_by_asset": [ + { + "asset_code": "USDC", + "payment_amounts": { + "draft": "10.0000000", + "ready": "", + "pending": "", + "paused": "", + "success": "", + "failed": "", + "average": "10.0000000", + "total": "10.0000000" + } + } + ], + "receiver_wallets_counters": { + "draft": 1, + "ready": 0, + "registered": 0, + "flagged": 0, + "total": 1 + }, + "total_receivers": 1 + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + }) +} diff --git a/internal/serve/httphandler/stellar_toml_handler.go b/internal/serve/httphandler/stellar_toml_handler.go new file mode 100644 index 000000000..da5bad551 --- /dev/null +++ b/internal/serve/httphandler/stellar_toml_handler.go @@ -0,0 +1,98 @@ +package httphandler + +import ( + "fmt" + "net/http" + "strings" + + "github.com/stellar/go/network" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" +) + +type StellarTomlHandler struct { + AnchorPlatformBaseSepURL string + DistributionPublicKey string + NetworkPassphrase string + Models *data.Models + Sep10SigningPublicKey string +} + +const ( + horizonPubnetURL = "https://horizon.stellar.org" + horizonTestnetURL = "https://horizon-testnet.stellar.org" +) + +func (s *StellarTomlHandler) horizonURL() string { + if s.NetworkPassphrase == network.PublicNetworkPassphrase { + return horizonPubnetURL + } + return horizonTestnetURL +} + +// buildGeneralInformation will create the general informations based on the env vars injected into the handler. +func (s *StellarTomlHandler) buildGeneralInformation() string { + webAuthEndpoint := s.AnchorPlatformBaseSepURL + "/auth" + transferServerSep0024 := s.AnchorPlatformBaseSepURL + "/sep24" + accounts := fmt.Sprintf("[%q, %q]", s.DistributionPublicKey, s.Sep10SigningPublicKey) + + return fmt.Sprintf(` + ACCOUNTS=%s + SIGNING_KEY=%q + NETWORK_PASSPHRASE=%q + HORIZON_URL=%q + WEB_AUTH_ENDPOINT=%q + TRANSFER_SERVER_SEP0024=%q + `, accounts, s.Sep10SigningPublicKey, s.NetworkPassphrase, s.horizonURL(), webAuthEndpoint, transferServerSep0024) +} + +func (s *StellarTomlHandler) buildOrganizationDocumentation(organization data.Organization) string { + return fmt.Sprintf(` + [DOCUMENTATION] + ORG_NAME=%q + `, organization.Name) +} + +// buildCurrencyInformation will create the currency information for all assets register in the application. +func (s *StellarTomlHandler) buildCurrencyInformation(assets []data.Asset) string { + strAssets := "" + for _, asset := range assets { + strAssets += fmt.Sprintf(` + [[CURRENCIES]] + code = %q + issuer = %q + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "%s" + `, asset.Code, asset.Issuer, asset.Code) + } + + return strAssets +} + +// ServeHTTP will serve the stellar.toml file needed to register users through SEP-24. +func (s StellarTomlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + assets, err := s.Models.Assets.GetAll(r.Context()) + ctx := r.Context() + if err != nil { + httperror.InternalError(ctx, "Cannot retrieve assets", err, nil).Render(w) + return + } + + organization, err := s.Models.Organizations.Get(r.Context()) + if err != nil { + httperror.InternalError(ctx, "Cannot retrieve organization", err, nil).Render(w) + return + } + + stellarToml := s.buildGeneralInformation() + s.buildOrganizationDocumentation(*organization) + s.buildCurrencyInformation(assets) + stellarToml = strings.TrimSpace(stellarToml) + stellarToml = strings.ReplaceAll(stellarToml, "\t", "") + + _, err = fmt.Fprint(w, stellarToml) + if err != nil { + httperror.InternalError(ctx, "Cannot write stellar.toml content", err, nil).Render(w) + return + } +} diff --git a/internal/serve/httphandler/stellar_toml_handler_test.go b/internal/serve/httphandler/stellar_toml_handler_test.go new file mode 100644 index 000000000..83f06413b --- /dev/null +++ b/internal/serve/httphandler/stellar_toml_handler_test.go @@ -0,0 +1,341 @@ +package httphandler + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/network" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_StellarTomlHandler_horizonURL(t *testing.T) { + testCases := []struct { + name string + s StellarTomlHandler + want string + }{ + { + name: "pubnet", + s: StellarTomlHandler{NetworkPassphrase: network.PublicNetworkPassphrase}, + want: horizonPubnetURL, + }, + { + name: "testnet", + s: StellarTomlHandler{NetworkPassphrase: network.TestNetworkPassphrase}, + want: horizonTestnetURL, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.s.horizonURL(); got != tc.want { + t.Errorf("StellarTomlHandler.horizonURL() = %v, want %v", got, tc.want) + } + }) + } +} + +func Test_StellarTomlHandler_buildGeneralInformation(t *testing.T) { + testCases := []struct { + name string + s StellarTomlHandler + want string + }{ + { + name: "pubnet", + s: StellarTomlHandler{ + DistributionPublicKey: "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", + NetworkPassphrase: network.PublicNetworkPassphrase, + Sep10SigningPublicKey: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + AnchorPlatformBaseSepURL: "https://anchor-platform-domain", + }, + want: fmt.Sprintf(` + ACCOUNTS=["GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"] + SIGNING_KEY="GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S" + NETWORK_PASSPHRASE=%q + HORIZON_URL=%q + WEB_AUTH_ENDPOINT="https://anchor-platform-domain/auth" + TRANSFER_SERVER_SEP0024="https://anchor-platform-domain/sep24" + `, network.PublicNetworkPassphrase, horizonPubnetURL), + }, + { + name: "testnet", + s: StellarTomlHandler{ + DistributionPublicKey: "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", + NetworkPassphrase: network.TestNetworkPassphrase, + Sep10SigningPublicKey: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + AnchorPlatformBaseSepURL: "https://anchor-platform-domain", + }, + want: fmt.Sprintf(` + ACCOUNTS=["GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"] + SIGNING_KEY="GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S" + NETWORK_PASSPHRASE=%q + HORIZON_URL=%q + WEB_AUTH_ENDPOINT="https://anchor-platform-domain/auth" + TRANSFER_SERVER_SEP0024="https://anchor-platform-domain/sep24" + `, network.TestNetworkPassphrase, horizonTestnetURL), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + genaralInformation := tc.s.buildGeneralInformation() + assert.Equal(t, tc.want, genaralInformation) + }) + } +} + +func Test_StellarTomlHandler_buildOrganizationDocumentation(t *testing.T) { + stellarTomlHandler := StellarTomlHandler{} + testCases := []struct { + name string + organization data.Organization + want string + }{ + { + name: "FOO Org", + organization: data.Organization{ + Name: "FOO Org", + }, + want: ` + [DOCUMENTATION] + ORG_NAME="FOO Org" + `, + }, + { + name: "BAR Org", + organization: data.Organization{ + Name: "BAR Org", + }, + want: ` + [DOCUMENTATION] + ORG_NAME="BAR Org" + `, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + genaralInformation := stellarTomlHandler.buildOrganizationDocumentation(tc.organization) + assert.Equal(t, tc.want, genaralInformation) + }) + } +} + +func Test_StellarTomlHandler_buildCurrencyInformation(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + s := StellarTomlHandler{} + + t.Run("build currency information without assets", func(t *testing.T) { + currencyInformation := s.buildCurrencyInformation([]data.Asset{}) + assert.Empty(t, currencyInformation) + }) + + t.Run("build currency information with asset", func(t *testing.T) { + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + + currencyInformation := s.buildCurrencyInformation([]data.Asset{*asset}) + wantStr := ` + [[CURRENCIES]] + code = "USDC" + issuer = "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC" + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "USDC" + ` + + assert.Equal(t, wantStr, currencyInformation) + }) + + t.Run("build currency information with multiple assets", func(t *testing.T) { + assets := data.ClearAndCreateAssetFixtures(t, ctx, dbConnectionPool) + + currencyInformation := s.buildCurrencyInformation(assets) + wantStr := ` + [[CURRENCIES]] + code = "EURT" + issuer = "GA62MH5RDXFWAIWHQEFNMO2SVDDCQLWOO3GO36VQB5LHUXL22DQ6IQAU" + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "EURT" + + [[CURRENCIES]] + code = "USDC" + issuer = "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE" + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "USDC" + ` + + assert.Equal(t, wantStr, currencyInformation) + }) +} + +func Test_StellarTomlHandler_ServeHTTP(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + data.ClearAndCreateAssetFixtures(t, ctx, dbConnectionPool) + + t.Run("build testnet toml", func(t *testing.T) { + tomlHandler := StellarTomlHandler{ + DistributionPublicKey: "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", + NetworkPassphrase: network.TestNetworkPassphrase, + Sep10SigningPublicKey: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + AnchorPlatformBaseSepURL: "https://anchor-platform-domain", + Models: models, + } + + r := chi.NewRouter() + r.Get("/.well-known/stellar.toml", tomlHandler.ServeHTTP) + + req, err := http.NewRequest("GET", "/.well-known/stellar.toml", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + wantToml := fmt.Sprintf(` + ACCOUNTS=["GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"] + SIGNING_KEY="GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S" + NETWORK_PASSPHRASE=%q + HORIZON_URL=%q + WEB_AUTH_ENDPOINT="https://anchor-platform-domain/auth" + TRANSFER_SERVER_SEP0024="https://anchor-platform-domain/sep24" + + [DOCUMENTATION] + ORG_NAME="MyCustomAid" + + [[CURRENCIES]] + code = "EURT" + issuer = "GA62MH5RDXFWAIWHQEFNMO2SVDDCQLWOO3GO36VQB5LHUXL22DQ6IQAU" + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "EURT" + + [[CURRENCIES]] + code = "USDC" + issuer = "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE" + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "USDC" + `, network.TestNetworkPassphrase, horizonTestnetURL) + wantToml = strings.TrimSpace(wantToml) + wantToml = strings.ReplaceAll(wantToml, "\t", "") + assert.Equal(t, wantToml, rr.Body.String()) + }) + + t.Run("build pubnet toml", func(t *testing.T) { + tomlHandler := StellarTomlHandler{ + DistributionPublicKey: "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", + NetworkPassphrase: network.PublicNetworkPassphrase, + Sep10SigningPublicKey: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + AnchorPlatformBaseSepURL: "https://anchor-platform-domain", + Models: models, + } + + r := chi.NewRouter() + r.Get("/.well-known/stellar.toml", tomlHandler.ServeHTTP) + + req, err := http.NewRequest("GET", "/.well-known/stellar.toml", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + wantToml := fmt.Sprintf(` + ACCOUNTS=["GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"] + SIGNING_KEY="GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S" + NETWORK_PASSPHRASE=%q + HORIZON_URL=%q + WEB_AUTH_ENDPOINT="https://anchor-platform-domain/auth" + TRANSFER_SERVER_SEP0024="https://anchor-platform-domain/sep24" + + [DOCUMENTATION] + ORG_NAME="MyCustomAid" + + [[CURRENCIES]] + code = "EURT" + issuer = "GA62MH5RDXFWAIWHQEFNMO2SVDDCQLWOO3GO36VQB5LHUXL22DQ6IQAU" + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "EURT" + + [[CURRENCIES]] + code = "USDC" + issuer = "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE" + is_asset_anchored = true + anchor_asset_type = "fiat" + status = "live" + desc = "USDC" + `, network.PublicNetworkPassphrase, horizonPubnetURL) + wantToml = strings.TrimSpace(wantToml) + wantToml = strings.ReplaceAll(wantToml, "\t", "") + assert.Equal(t, wantToml, rr.Body.String()) + }) + + t.Run("build toml without assets in database", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + tomlHandler := StellarTomlHandler{ + DistributionPublicKey: "GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", + NetworkPassphrase: network.PublicNetworkPassphrase, + Sep10SigningPublicKey: "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S", + AnchorPlatformBaseSepURL: "https://anchor-platform-domain", + Models: models, + } + + r := chi.NewRouter() + r.Get("/.well-known/stellar.toml", tomlHandler.ServeHTTP) + + req, err := http.NewRequest("GET", "/.well-known/stellar.toml", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + wantToml := fmt.Sprintf(` + ACCOUNTS=["GBC2HVWFIFN7WJHFORVBCDKJORG6LWTW3O2QBHOURL3KHZPM4KMWTUSA", "GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S"] + SIGNING_KEY="GAX46JJZ3NPUM2EUBTTGFM6ITDF7IGAFNBSVWDONPYZJREHFPP2I5U7S" + NETWORK_PASSPHRASE=%q + HORIZON_URL=%q + WEB_AUTH_ENDPOINT="https://anchor-platform-domain/auth" + TRANSFER_SERVER_SEP0024="https://anchor-platform-domain/sep24" + + [DOCUMENTATION] + ORG_NAME="MyCustomAid" + `, network.PublicNetworkPassphrase, horizonPubnetURL) + wantToml = strings.TrimSpace(wantToml) + wantToml = strings.ReplaceAll(wantToml, "\t", "") + assert.Equal(t, wantToml, rr.Body.String()) + }) +} diff --git a/internal/serve/httphandler/update_receiver_handler.go b/internal/serve/httphandler/update_receiver_handler.go new file mode 100644 index 000000000..eabbbca72 --- /dev/null +++ b/internal/serve/httphandler/update_receiver_handler.go @@ -0,0 +1,112 @@ +package httphandler + +import ( + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" +) + +type UpdateReceiverHandler struct { + Models *data.Models + DBConnectionPool db.DBConnectionPool +} + +func createVerificationInsert(updateReceiverInfo *validators.UpdateReceiverRequest, receiverID string) []data.ReceiverVerificationInsert { + receiverVerifications := []data.ReceiverVerificationInsert{} + + if updateReceiverInfo.DateOfBirth != "" { + receiverVerifications = append(receiverVerifications, data.ReceiverVerificationInsert{ + ReceiverID: receiverID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: updateReceiverInfo.DateOfBirth, + }) + } + + if updateReceiverInfo.Pin != "" { + receiverVerifications = append(receiverVerifications, data.ReceiverVerificationInsert{ + ReceiverID: receiverID, + VerificationField: data.VerificationFieldPin, + VerificationValue: updateReceiverInfo.Pin, + }) + } + + if updateReceiverInfo.NationalID != "" { + receiverVerifications = append(receiverVerifications, data.ReceiverVerificationInsert{ + ReceiverID: receiverID, + VerificationField: data.VerificationFieldNationalID, + VerificationValue: updateReceiverInfo.NationalID, + }) + } + + return receiverVerifications +} + +func (h UpdateReceiverHandler) UpdateReceiver(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + var reqBody validators.UpdateReceiverRequest + err := httpdecode.DecodeJSON(req, &reqBody) + if err != nil { + err = fmt.Errorf("decoding the request body: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + // validate request payload + validator := validators.NewUpdateReceiverValidator() + validator.ValidateReceiver(&reqBody) + if validator.HasErrors() { + log.Ctx(ctx).Errorf("request invalid: %s", validator.Errors) + httperror.BadRequest("request invalid", nil, validator.Errors).Render(rw) + return + } + + receiverID := chi.URLParam(req, "id") + receiverVerifications := createVerificationInsert(&reqBody, receiverID) + receiver, err := db.RunInTransactionWithResult(ctx, h.DBConnectionPool, nil, func(dbTx db.DBTransaction) (response *data.Receiver, innerErr error) { + for _, rv := range receiverVerifications { + innerErr = h.Models.ReceiverVerification.UpdateVerificationValue( + req.Context(), + h.Models.DBConnectionPool, + rv.ReceiverID, + rv.VerificationField, + rv.VerificationValue, + ) + + if innerErr != nil { + return nil, fmt.Errorf("error updating receiver verification %s: %w", rv.VerificationField, innerErr) + } + } + + receiverUpdate := data.ReceiverUpdate{ + Email: reqBody.Email, + ExternalId: reqBody.ExternalID, + } + if receiverUpdate.Email != "" || receiverUpdate.ExternalId != "" { + if innerErr = h.Models.Receiver.Update(ctx, dbTx, receiverID, receiverUpdate); innerErr != nil { + return nil, fmt.Errorf("error updating receiver with ID %s: %w", receiverID, innerErr) + } + } + + receiver, innerErr := h.Models.Receiver.Get(ctx, h.Models.DBConnectionPool, receiverID) + if innerErr != nil { + return nil, fmt.Errorf("error querying receiver with ID %s: %w", receiverID, innerErr) + } + + return receiver, nil + }) + if err != nil { + httperror.InternalError(ctx, "", err, nil).Render(rw) + } + + httpjson.RenderStatus(rw, http.StatusOK, receiver, httpjson.JSON) +} diff --git a/internal/serve/httphandler/update_receiver_handler_test.go b/internal/serve/httphandler/update_receiver_handler_test.go new file mode 100644 index 000000000..8fb14dd6c --- /dev/null +++ b/internal/serve/httphandler/update_receiver_handler_test.go @@ -0,0 +1,478 @@ +package httphandler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_UpdateReceiverHandler_createVerificationInsert(t *testing.T) { + receiverID := "mock_id" + + verificationDOB := data.ReceiverVerificationInsert{ + ReceiverID: receiverID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1999-01-01", + } + + verificationPIN := data.ReceiverVerificationInsert{ + ReceiverID: receiverID, + VerificationField: data.VerificationFieldPin, + VerificationValue: "123", + } + + verificationNationalID := data.ReceiverVerificationInsert{ + ReceiverID: receiverID, + VerificationField: data.VerificationFieldNationalID, + VerificationValue: "12345CODE", + } + + testCases := []struct { + name string + updateReceiverRequest validators.UpdateReceiverRequest + want []data.ReceiverVerificationInsert + }{ + { + name: "empty update request", + updateReceiverRequest: validators.UpdateReceiverRequest{}, + want: []data.ReceiverVerificationInsert{}, + }, + { + name: "insert receiver verification date of birth", + updateReceiverRequest: validators.UpdateReceiverRequest{DateOfBirth: "1999-01-01"}, + want: []data.ReceiverVerificationInsert{verificationDOB}, + }, + { + name: "insert receiver verification pin", + updateReceiverRequest: validators.UpdateReceiverRequest{Pin: "123"}, + want: []data.ReceiverVerificationInsert{verificationPIN}, + }, + { + name: "insert receiver verification national ID", + updateReceiverRequest: validators.UpdateReceiverRequest{NationalID: "12345CODE"}, + want: []data.ReceiverVerificationInsert{verificationNationalID}, + }, + { + name: "insert multipes receiver verification values", + updateReceiverRequest: validators.UpdateReceiverRequest{ + DateOfBirth: "1999-01-01", + Pin: "123", + NationalID: "12345CODE", + }, + want: []data.ReceiverVerificationInsert{verificationDOB, verificationPIN, verificationNationalID}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + updateReceiverRequest := tc.updateReceiverRequest + receiverVerifications := createVerificationInsert(&updateReceiverRequest, receiverID) + + assert.Equal(t, tc.want, receiverVerifications) + }) + } +} + +func Test_UpdateReceiverHandler(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + handler := &UpdateReceiverHandler{ + Models: models, + DBConnectionPool: dbConnectionPool, + } + + ctx := context.Background() + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + Email: &[]string{"receiver@email.com"}[0], + ExternalID: "externalID", + }) + + // setup + r := chi.NewRouter() + r.Patch("/receivers/{id}", handler.UpdateReceiver) + + t.Run("error invalid request body", func(t *testing.T) { + testCases := []struct { + name string + request validators.UpdateReceiverRequest + want string + }{ + { + name: "empty request body", + request: validators.UpdateReceiverRequest{}, + want: ` + { + "error": "request invalid", + "extras": { + "body": "request body is empty" + } + } + `, + }, + { + name: "invalid date of birth", + request: validators.UpdateReceiverRequest{DateOfBirth: "invalid"}, + want: ` + { + "error": "request invalid", + "extras": { + "date_of_birth": "invalid date of birth format. Correct format: 1990-01-30" + } + } + `, + }, + { + name: "invalid pin", + request: validators.UpdateReceiverRequest{Pin: " "}, + want: ` + { + "error": "request invalid", + "extras": { + "pin": "invalid pin format" + } + } + `, + }, + { + name: "invalid national ID", + request: validators.UpdateReceiverRequest{NationalID: " "}, + want: ` + { + "error": "request invalid", + "extras": { + "national_id": "invalid national ID format" + } + } + `, + }, + { + name: "invalid email", + request: validators.UpdateReceiverRequest{Email: "invalid"}, + want: ` + { + "error": "request invalid", + "extras": { + "email": "invalid email format" + } + } + `, + }, + { + name: "invalid external ID", + request: validators.UpdateReceiverRequest{ExternalID: " "}, + want: ` + { + "error": "request invalid", + "extras": { + "external_id": "invalid external_id format" + } + } + `, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + route := fmt.Sprintf("/receivers/%s", receiver.ID) + reqBody, err := json.Marshal(tc.request) + require.NoError(t, err) + req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, tc.want, string(respBody)) + }) + } + }) + + t.Run("update date of birth value", func(t *testing.T) { + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "2000-01-01", + }) + + request := validators.UpdateReceiverRequest{DateOfBirth: "1999-01-01"} + + route := fmt.Sprintf("/receivers/%s", receiver.ID) + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + query := ` + SELECT + hashed_value + FROM + receiver_verifications + WHERE + receiver_id = $1 AND + verification_field = $2 + ` + + newReceiverVerification := data.ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, data.VerificationFieldDateOfBirth) + require.NoError(t, err) + + assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "1999-01-01")) + assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "2000-01-01")) + + receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.Equal(t, "receiver@email.com", *receiverDB.Email) + assert.Equal(t, "externalID", receiverDB.ExternalID) + }) + + t.Run("update pin value", func(t *testing.T) { + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldPin, + VerificationValue: "890", + }) + + request := validators.UpdateReceiverRequest{Pin: "123"} + + route := fmt.Sprintf("/receivers/%s", receiver.ID) + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + query := ` + SELECT + hashed_value + FROM + receiver_verifications + WHERE + receiver_id = $1 AND + verification_field = $2 + ` + + newReceiverVerification := data.ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, data.VerificationFieldPin) + require.NoError(t, err) + + assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "123")) + assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "890")) + + receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.Equal(t, "receiver@email.com", *receiverDB.Email) + assert.Equal(t, "externalID", receiverDB.ExternalID) + }) + + t.Run("update national ID value", func(t *testing.T) { + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldNationalID, + VerificationValue: "OLDID890", + }) + + request := validators.UpdateReceiverRequest{NationalID: "NEWID123"} + + route := fmt.Sprintf("/receivers/%s", receiver.ID) + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + query := ` + SELECT + hashed_value + FROM + receiver_verifications + WHERE + receiver_id = $1 AND + verification_field = $2 + ` + + newReceiverVerification := data.ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, data.VerificationFieldNationalID) + require.NoError(t, err) + + assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "NEWID123")) + assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, "OLDID890")) + + receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.Equal(t, "receiver@email.com", *receiverDB.Email) + assert.Equal(t, "externalID", receiverDB.ExternalID) + }) + + t.Run("update multiples receiver verifications values", func(t *testing.T) { + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "2000-01-01", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldPin, + VerificationValue: "890", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldNationalID, + VerificationValue: "OLDID890", + }) + + request := validators.UpdateReceiverRequest{ + DateOfBirth: "1999-01-01", + Pin: "123", + NationalID: "NEWID123", + } + + route := fmt.Sprintf("/receivers/%s", receiver.ID) + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("PATCH", route, strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + query := ` + SELECT + hashed_value + FROM + receiver_verifications + WHERE + receiver_id = $1 AND + verification_field = $2 + ` + + receiverVerifications := []struct { + verificationField data.VerificationField + newVerificationValue string + oldVerificationValue string + }{ + { + verificationField: data.VerificationFieldDateOfBirth, + newVerificationValue: "1999-01-01", + oldVerificationValue: "2000-01-01", + }, + { + verificationField: data.VerificationFieldPin, + newVerificationValue: "123", + oldVerificationValue: "890", + }, + { + verificationField: data.VerificationFieldNationalID, + newVerificationValue: "NEWID123", + oldVerificationValue: "OLDID890", + }, + } + for _, v := range receiverVerifications { + newReceiverVerification := data.ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &newReceiverVerification, query, receiver.ID, v.verificationField) + require.NoError(t, err) + + assert.True(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, v.newVerificationValue)) + assert.False(t, data.CompareVerificationValue(newReceiverVerification.HashedValue, v.oldVerificationValue)) + + receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.Equal(t, "receiver@email.com", *receiverDB.Email) + assert.Equal(t, "externalID", receiverDB.ExternalID) + } + }) + + t.Run("updates receiver's email", func(t *testing.T) { + request := validators.UpdateReceiverRequest{ + Email: "update_receiver@email.com", + } + + route := fmt.Sprintf("/receivers/%s", receiver.ID) + reqBody, err := json.Marshal(request) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + assert.Equal(t, "update_receiver@email.com", *receiverDB.Email) + }) + + t.Run("updates receiver's external ID", func(t *testing.T) { + request := validators.UpdateReceiverRequest{ + ExternalID: "newExternalID", + } + + route := fmt.Sprintf("/receivers/%s", receiver.ID) + reqBody, err := json.Marshal(request) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPatch, route, strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + receiverDB, err := models.Receiver.Get(ctx, dbConnectionPool, receiver.ID) + require.NoError(t, err) + + assert.Equal(t, "newExternalID", receiverDB.ExternalID) + }) +} diff --git a/internal/serve/httphandler/user_handler.go b/internal/serve/httphandler/user_handler.go new file mode 100644 index 000000000..d06579bd2 --- /dev/null +++ b/internal/serve/httphandler/user_handler.go @@ -0,0 +1,285 @@ +package httphandler + +import ( + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/stellar/go/support/http/httpdecode" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +const invitationMessageTitle = "Welcome to Stellar Disbursement Platform" + +type UserHandler struct { + AuthManager auth.AuthManager + MessengerClient message.MessengerClient + UIBaseURL string + Models *data.Models +} + +type UserActivationRequest struct { + UserID string `json:"user_id"` + IsActive *bool `json:"is_active"` +} + +func (uar UserActivationRequest) validate() *httperror.HTTPError { + validator := validators.NewValidator() + + validator.Check(uar.UserID != "", "user_id", "user_id is required") + validator.Check(uar.IsActive != nil, "is_active", "is_active is required") + + if validator.HasErrors() { + return httperror.BadRequest("Request invalid", nil, validator.Errors) + } + + return nil +} + +type CreateUserRequest struct { + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + Roles []data.UserRole `json:"roles"` +} + +func (cur CreateUserRequest) validate() *httperror.HTTPError { + validator := validators.NewValidator() + + validator.Check(cur.FirstName != "", "fist_name", "fist_name is required") + validator.Check(cur.LastName != "", "last_name", "last_name is required") + validator.Check(cur.Email != "", "email", "email is required") + validateRoles(validator, cur.Roles) + + if validator.HasErrors() { + return httperror.BadRequest("Request invalid", nil, validator.Errors) + } + + return nil +} + +type UpdateRolesRequest struct { + UserID string `json:"user_id"` + Roles []data.UserRole `json:"roles"` +} + +func (upr UpdateRolesRequest) validate() *httperror.HTTPError { + validator := validators.NewValidator() + + validator.Check(upr.UserID != "", "user_id", "user_id is required") + validateRoles(validator, upr.Roles) + + if validator.HasErrors() { + return httperror.BadRequest("Request invalid", nil, validator.Errors) + } + + return nil +} + +func validateRoles(validator *validators.Validator, roles []data.UserRole) { + // NOTE: in the MVP, users should have only one role. + validator.Check(len(roles) == 1, "roles", "the number of roles required is exactly one") + + // Validating the role of the request is a valid value + if _, ok := validator.Errors["roles"]; !ok { + role := roles[0] + validator.Check(role.IsValid(), "roles", fmt.Sprintf("unexpected value for roles[0]=%s. Expect one of these values: %s", role, data.GetAllRoles())) + } +} + +func (h UserHandler) UserActivation(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + log.Ctx(ctx).Warn("token not found when updating user activation") + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + var reqBody UserActivationRequest + if err := httpdecode.DecodeJSON(req, &reqBody); err != nil { + err = fmt.Errorf("decoding the request body: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + if err := reqBody.validate(); err != nil { + err.Render(rw) + return + } + + var activationErr error + if *reqBody.IsActive { + activationErr = h.AuthManager.ActivateUser(ctx, token, reqBody.UserID) + } else { + activationErr = h.AuthManager.DeactivateUser(ctx, token, reqBody.UserID) + } + + if activationErr != nil { + if errors.Is(activationErr, auth.ErrInvalidToken) { + httperror.Unauthorized("", activationErr, nil).Render(rw) + return + } + + if errors.Is(activationErr, auth.ErrNoRowsAffected) { + httperror.BadRequest("", activationErr, map[string]interface{}{"user_id": "user_id is invalid"}).Render(rw) + return + } + + httperror.InternalError(ctx, "Cannot update user activation", activationErr, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]string{"message": "user activation was updated successfully"}, httpjson.JSON) +} + +func (h UserHandler) CreateUser(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + var reqBody CreateUserRequest + if err := httpdecode.DecodeJSON(req, &reqBody); err != nil { + err = fmt.Errorf("decoding the request body: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + if err := reqBody.validate(); err != nil { + err.Render(rw) + return + } + + newUser := &auth.User{ + FirstName: reqBody.FirstName, + LastName: reqBody.LastName, + Email: reqBody.Email, + Roles: data.FromUserRoleArrayToStringArray(reqBody.Roles), + } + + // The password is empty so the AuthManager will generate one automatically. + u, err := h.AuthManager.CreateUser(ctx, newUser, "") + if err != nil { + if errors.Is(err, auth.ErrUserEmailAlreadyExists) { + httperror.BadRequest(auth.ErrUserEmailAlreadyExists.Error(), err, nil).Render(rw) + return + } + + httperror.InternalError(ctx, "Cannot create user", err, nil).Render(rw) + return + } + + organization, err := h.Models.Organizations.Get(ctx) + if err != nil { + httperror.InternalError(ctx, "Cannot get organization data", err, nil).Render(rw) + return + } + + forgotPasswordLink, err := url.JoinPath(h.UIBaseURL, "forgot-password") + if err != nil { + httperror.InternalError(ctx, "Cannot get forgot password link", err, nil).Render(rw) + return + } + + invitationMsgData := htmltemplate.InvitationMessageTemplate{ + FirstName: u.FirstName, + Role: u.Roles[0], + ForgotPasswordLink: forgotPasswordLink, + OrganizationName: organization.Name, + } + messageContent, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(invitationMsgData) + if err != nil { + httperror.InternalError(ctx, "Cannot execute invitation message template", err, nil).Render(rw) + return + } + + msg := message.Message{ + ToEmail: u.Email, + Message: messageContent, + Title: invitationMessageTitle, + } + err = h.MessengerClient.SendMessage(msg) + if err != nil { + msg := fmt.Sprintf("Cannot send invitation email for user %s", u.ID) + httperror.InternalError(ctx, msg, err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusCreated, u, httpjson.JSON) +} + +func (h UserHandler) UpdateUserRoles(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + log.Ctx(ctx).Warn("token not found when updating user roles") + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + var reqBody UpdateRolesRequest + if err := httpdecode.DecodeJSON(req, &reqBody); err != nil { + err = fmt.Errorf("decoding the request body: %w", err) + log.Ctx(ctx).Error(err) + httperror.BadRequest("", err, nil).Render(rw) + return + } + + if err := reqBody.validate(); err != nil { + err.Render(rw) + return + } + + updateUserRolesErr := h.AuthManager.UpdateUserRoles(ctx, token, reqBody.UserID, data.FromUserRoleArrayToStringArray(reqBody.Roles)) + if updateUserRolesErr != nil { + if errors.Is(updateUserRolesErr, auth.ErrInvalidToken) { + httperror.Unauthorized("", updateUserRolesErr, nil).Render(rw) + return + } + + if errors.Is(updateUserRolesErr, auth.ErrNoRowsAffected) { + httperror.BadRequest("", updateUserRolesErr, map[string]interface{}{"user_id": "user_id is invalid"}).Render(rw) + return + } + + httperror.InternalError(ctx, "Cannot update user activation", updateUserRolesErr, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, map[string]string{"message": "user roles were updated successfully"}, httpjson.JSON) +} + +func (h UserHandler) GetAllUsers(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + log.Ctx(ctx).Warn("token not found when getting all users") + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + users, err := h.AuthManager.GetAllUsers(ctx, token) + if err != nil { + if errors.Is(err, auth.ErrInvalidToken) { + httperror.Unauthorized("", err, nil).Render(rw) + return + } + + httperror.InternalError(ctx, "Cannot get all users", err, nil).Render(rw) + return + } + + httpjson.RenderStatus(rw, http.StatusOK, users, httpjson.JSON) +} diff --git a/internal/serve/httphandler/user_handler_test.go b/internal/serve/httphandler/user_handler_test.go new file mode 100644 index 000000000..0f193f291 --- /dev/null +++ b/internal/serve/httphandler/user_handler_test.go @@ -0,0 +1,1342 @@ +package httphandler + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + urllib "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/htmltemplate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_UserHandler_UserActivation(t *testing.T) { + r := chi.NewRouter() + + authenticatorMock := &auth.AuthenticatorMock{} + jwtManagerMock := &auth.JWTManagerMock{} + roleManagerMock := &auth.RoleManagerMock{} + authManager := auth.NewAuthManager( + auth.WithCustomAuthenticatorOption(authenticatorMock), + auth.WithCustomJWTManagerOption(jwtManagerMock), + auth.WithCustomRoleManagerOption(roleManagerMock), + ) + + handler := &UserHandler{AuthManager: authManager} + + const url = "/users/activation" + + r.Patch(url, handler.UserActivation) + + t.Run("returns Unauthorized when no token is in the request context", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns error when request body is invalid", func(t *testing.T) { + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, "mytoken") + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`{}`)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Request invalid", + "extras": { + "user_id": "user_id is required", + "is_active": "is_active is required" + } + } + ` + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`{"user_id": "user-id"}`)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "Request invalid", + "extras": { + "is_active": "is_active is required" + } + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`{"is_active": true}`)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "Request invalid", + "extras": { + "user_id": "user_id is required" + } + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`"invalid"`)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "The request was invalid in some way." + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + assert.Contains(t, buf.String(), "decoding the request body") + }) + + t.Run("returns Unauthorized when token is invalid", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(false, nil). + Twice() + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + // Activating the user + reqBody := ` + { + "user_id": "user-id", + "is_active": true + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + + // Deactivating the user + reqBody = ` + { + "user_id": "user-id", + "is_active": false + } + ` + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns BadRequest when user doesn't exist", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(true, nil). + Twice() + + authenticatorMock. + On("ActivateUser", mock.Anything, "user-id"). + Return(auth.ErrNoRowsAffected). + Once(). + On("DeactivateUser", mock.Anything, "user-id"). + Return(auth.ErrNoRowsAffected). + Once() + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + // Activating the user + reqBody := ` + { + "user_id": "user-id", + "is_active": true + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"user_id":"user_id is invalid"}}`, string(respBody)) + + // Deactivating the user + reqBody = ` + { + "user_id": "user-id", + "is_active": false + } + ` + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"user_id":"user_id is invalid"}}`, string(respBody)) + }) + + t.Run("returns InternalServerError when a unexpected error occurs", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(false, errors.New("unexpected error")). + Once() + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + reqBody := ` + { + "user_id": "user-id", + "is_active": true + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot update user activation"}`, string(respBody)) + assert.Contains(t, buf.String(), "Cannot update user activation") + }) + + t.Run("updates the user activation correctly", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(true, nil). + Twice() + + authenticatorMock. + On("ActivateUser", mock.Anything, "user-id"). + Return(nil). + Once(). + On("DeactivateUser", mock.Anything, "user-id"). + Return(nil). + Once() + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + // Activating the user + reqBody := ` + { + "user_id": "user-id", + "is_active": true + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "user activation was updated successfully"}`, string(respBody)) + + // Deactivating the user + reqBody = ` + { + "user_id": "user-id", + "is_active": false + } + ` + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "user activation was updated successfully"}`, string(respBody)) + }) +} + +func Test_CreateUserRequest_validate(t *testing.T) { + cur := CreateUserRequest{ + FirstName: "", + LastName: "", + Email: "", + Roles: []data.UserRole{}, + } + + extras := map[string]interface{}{ + "email": "email is required", + "fist_name": "fist_name is required", + "last_name": "last_name is required", + "roles": "the number of roles required is exactly one", + } + expectedErr := httperror.BadRequest("Request invalid", nil, extras) + + err := cur.validate() + assert.Equal(t, expectedErr, err) + + cur = CreateUserRequest{ + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []data.UserRole{data.BusinessUserRole, data.DeveloperUserRole}, + } + + extras = map[string]interface{}{ + "roles": "the number of roles required is exactly one", + } + expectedErr = httperror.BadRequest("Request invalid", nil, extras) + + err = cur.validate() + assert.Equal(t, expectedErr, err) + + cur = CreateUserRequest{ + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []data.UserRole{data.DeveloperUserRole}, + } + + err = cur.validate() + assert.Nil(t, err) +} + +func Test_UserHandler_CreateUser(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + r := chi.NewRouter() + + authenticatorMock := &auth.AuthenticatorMock{} + authManager := auth.NewAuthManager(auth.WithCustomAuthenticatorOption(authenticatorMock)) + + messengerClientMock := &message.MessengerClientMock{} + uiBaseURL := "https://sdp.com" + handler := &UserHandler{ + AuthManager: authManager, + MessengerClient: messengerClientMock, + UIBaseURL: uiBaseURL, + Models: models, + } + + const url = "/users" + + r.Post(url, handler.CreateUser) + + t.Run("returns error when request body is invalid", func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{}`)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Request invalid", + "extras": { + "email": "email is required", + "fist_name": "fist_name is required", + "last_name": "last_name is required", + "roles": "the number of roles required is exactly one" + } + } + ` + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + body := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "roles": ["role1", "role2"] + } + ` + req, err = http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "Request invalid", + "extras": { + "roles": "the number of roles required is exactly one" + } + } + ` + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + body = ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "roles": ["role1"] + } + ` + req, err = http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "Request invalid", + "extras": { + "roles": "unexpected value for roles[0]=role1. Expect one of these values: [owner financial_controller developer business]" + } + } + ` + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err = http.NewRequest(http.MethodPost, url, strings.NewReader(`"invalid"`)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "The request was invalid in some way." + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + assert.Contains(t, buf.String(), "decoding the request body") + }) + + t.Run("returns error when Auth Manager fails", func(t *testing.T) { + u := &auth.User{ + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []string{data.DeveloperUserRole.String()}, + } + + authenticatorMock. + On("CreateUser", mock.Anything, u, ""). + Return(nil, errors.New("unexpected error")). + Once() + + body := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "roles": ["developer"] + } + ` + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Cannot create user" + } + ` + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("returns Bad Request when user is duplicated", func(t *testing.T) { + u := &auth.User{ + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []string{data.DeveloperUserRole.String()}, + } + + authenticatorMock. + On("CreateUser", mock.Anything, u, ""). + Return(nil, auth.ErrUserEmailAlreadyExists). + Once() + + body := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "roles": ["developer"] + } + ` + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "a user with this email already exists"}`, string(respBody)) + }) + + t.Run("returns error when sending email fails", func(t *testing.T) { + u := &auth.User{ + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []string{data.DeveloperUserRole.String()}, + } + + expectedUser := &auth.User{ + ID: "user-id", + FirstName: u.FirstName, + LastName: u.LastName, + Email: u.Email, + Roles: u.Roles, + } + + authenticatorMock. + On("CreateUser", mock.Anything, u, ""). + Return(expectedUser, nil). + Once() + + forgotPasswordLink, err := urllib.JoinPath(uiBaseURL, "forgot-password") + require.NoError(t, err) + + content, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(htmltemplate.InvitationMessageTemplate{ + FirstName: u.FirstName, + Role: u.Roles[0], + ForgotPasswordLink: forgotPasswordLink, + OrganizationName: "MyCustomAid", + }) + require.NoError(t, err) + + msg := message.Message{ + ToEmail: u.Email, + Title: invitationMessageTitle, + Message: content, + } + messengerClientMock. + On("SendMessage", msg). + Return(errors.New("unexpected error")). + Once() + + body := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "roles": ["developer"] + } + ` + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Cannot send invitation email for user user-id" + } + ` + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("returns error when joining the forgot password link", func(t *testing.T) { + u := &auth.User{ + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []string{data.DeveloperUserRole.String()}, + } + + expectedUser := &auth.User{ + ID: "user-id", + FirstName: u.FirstName, + LastName: u.LastName, + Email: u.Email, + Roles: u.Roles, + } + + authenticatorMock. + On("CreateUser", mock.Anything, u, ""). + Return(expectedUser, nil). + Once() + + body := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "roles": ["developer"] + } + ` + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + http.HandlerFunc(UserHandler{ + AuthManager: authManager, + MessengerClient: messengerClientMock, + UIBaseURL: "%invalid%", + Models: models, + }.CreateUser).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Cannot get forgot password link" + } + ` + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + t.Run("creates user successfully", func(t *testing.T) { + u := &auth.User{ + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + Roles: []string{data.DeveloperUserRole.String()}, + } + + expectedUser := &auth.User{ + ID: "user-id", + FirstName: u.FirstName, + LastName: u.LastName, + Email: u.Email, + Roles: u.Roles, + IsActive: true, + } + + authenticatorMock. + On("CreateUser", mock.Anything, u, ""). + Return(expectedUser, nil). + Once() + + forgotPasswordLink, err := urllib.JoinPath(uiBaseURL, "forgot-password") + require.NoError(t, err) + + content, err := htmltemplate.ExecuteHTMLTemplateForInvitationMessage(htmltemplate.InvitationMessageTemplate{ + FirstName: u.FirstName, + Role: u.Roles[0], + ForgotPasswordLink: forgotPasswordLink, + OrganizationName: "MyCustomAid", + }) + require.NoError(t, err) + + msg := message.Message{ + ToEmail: u.Email, + Title: invitationMessageTitle, + Message: content, + } + messengerClientMock. + On("SendMessage", msg). + Return(nil). + Once() + + body := ` + { + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "roles": ["developer"] + } + ` + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "id": "user-id", + "first_name": "First", + "last_name": "Last", + "email": "email@email.com", + "is_active": true, + "roles": ["developer"] + } + ` + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) + + authenticatorMock.AssertExpectations(t) + messengerClientMock.AssertExpectations(t) +} + +func Test_UpdateRolesRequest_validate(t *testing.T) { + upr := UpdateRolesRequest{ + UserID: "", + Roles: []data.UserRole{}, + } + + extras := map[string]interface{}{ + "user_id": "user_id is required", + "roles": "the number of roles required is exactly one", + } + expectedErr := httperror.BadRequest("Request invalid", nil, extras) + + err := upr.validate() + assert.Equal(t, expectedErr, err) + + upr = UpdateRolesRequest{ + UserID: "user_id", + Roles: []data.UserRole{data.BusinessUserRole, data.DeveloperUserRole}, + } + + extras = map[string]interface{}{ + "roles": "the number of roles required is exactly one", + } + expectedErr = httperror.BadRequest("Request invalid", nil, extras) + + err = upr.validate() + assert.Equal(t, expectedErr, err) + + upr = UpdateRolesRequest{ + UserID: "user_id", + Roles: []data.UserRole{data.DeveloperUserRole}, + } + + err = upr.validate() + assert.Nil(t, err) +} + +func Test_UserHandler_UpdateUserRoles(t *testing.T) { + r := chi.NewRouter() + + jwtManagerMock := &auth.JWTManagerMock{} + roleManagerMock := &auth.RoleManagerMock{} + authManager := auth.NewAuthManager( + auth.WithCustomJWTManagerOption(jwtManagerMock), + auth.WithCustomRoleManagerOption(roleManagerMock), + ) + + handler := &UserHandler{AuthManager: authManager} + + const url = "/users/roles" + r.Patch(url, handler.UpdateUserRoles) + + t.Run("returns Unauthorized when no token is in the request context", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns error when request body is invalid", func(t *testing.T) { + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, "mytoken") + + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`{}`)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + { + "error": "Request invalid", + "extras": { + "user_id": "user_id is required", + "roles": "the number of roles required is exactly one" + } + } + ` + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + body := ` + { + "user_id": "user-id", + "roles": ["role1", "role2"] + } + ` + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(body)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "Request invalid", + "extras": { + "roles": "the number of roles required is exactly one" + } + } + ` + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + body = ` + { + "user_id": "user-id", + "roles": ["role1"] + } + ` + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(body)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "Request invalid", + "extras": { + "roles": "unexpected value for roles[0]=role1. Expect one of these values: [owner financial_controller developer business]" + } + } + ` + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + req, err = http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(`"invalid"`)) + require.NoError(t, err) + + w = httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp = w.Result() + + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody = ` + { + "error": "The request was invalid in some way." + } + ` + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + assert.Contains(t, buf.String(), "decoding the request body") + }) + + t.Run("returns Unauthorized when token is invalid", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(false, nil). + Once() + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + reqBody := ` + { + "user_id": "user-id", + "roles": ["developer"] + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns BadRequest when user doesn't exist", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(true, nil). + Once() + + roleManagerMock. + On("UpdateRoles", mock.Anything, &auth.User{ID: "user-id"}, []string{data.DeveloperUserRole.String()}). + Return(auth.ErrNoRowsAffected). + Once() + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + reqBody := ` + { + "user_id": "user-id", + "roles": ["developer"] + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "The request was invalid in some way.", "extras": {"user_id":"user_id is invalid"}}`, string(respBody)) + }) + + t.Run("returns InternalServerError when a unexpected error occurs", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(false, errors.New("unexpected error")). + Once() + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + reqBody := ` + { + "user_id": "user-id", + "roles": ["developer"] + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot update user activation"}`, string(respBody)) + assert.Contains(t, buf.String(), "Cannot update user activation") + }) + + t.Run("updates the user activation correctly", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(true, nil). + Once() + + roleManagerMock. + On("UpdateRoles", mock.Anything, &auth.User{ID: "user-id"}, []string{data.DeveloperUserRole.String()}). + Return(nil). + Once() + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + reqBody := ` + { + "user_id": "user-id", + "roles": ["developer"] + } + ` + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, url, strings.NewReader(reqBody)) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"message": "user roles were updated successfully"}`, string(respBody)) + }) +} + +func Test_UserHandler_GetAllUsers(t *testing.T) { + jwtManagerMock := &auth.JWTManagerMock{} + authenticatorMock := &auth.AuthenticatorMock{} + authManager := auth.NewAuthManager( + auth.WithCustomJWTManagerOption(jwtManagerMock), + auth.WithCustomAuthenticatorOption(authenticatorMock), + ) + + handler := &UserHandler{AuthManager: authManager} + + const url = "/users" + + t.Run("returns Unauthorized when no token is in the request context", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + http.HandlerFunc(handler.GetAllUsers).ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns Unauthorized when token is invalid", func(t *testing.T) { + token := "mytoken" + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), token). + Return(false, nil). + Once() + + w := httptest.NewRecorder() + + http.HandlerFunc(handler.GetAllUsers).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns InternalServerError when a unexpected error occurs", func(t *testing.T) { + token := "mytoken" + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), token). + Return(false, errors.New("unexpected error")). + Once() + + w := httptest.NewRecorder() + + http.HandlerFunc(handler.GetAllUsers).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot get all users"}`, string(respBody)) + assert.Contains(t, buf.String(), "Cannot get all users") + }) + + t.Run("returns all users successfully", func(t *testing.T) { + token := "mytoken" + + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + jwtManagerMock. + On("ValidateToken", req.Context(), token). + Return(true, nil). + Once() + + authenticatorMock. + On("GetAllUsers", req.Context()). + Return([]auth.User{ + { + ID: "user1-ID", + FirstName: "First", + LastName: "Last", + Email: "user1@email.com", + IsOwner: false, + IsActive: false, + Roles: []string{data.BusinessUserRole.String()}, + }, + { + ID: "user2-ID", + FirstName: "First", + LastName: "Last", + Email: "user2@email.com", + IsOwner: true, + IsActive: true, + Roles: []string{data.OwnerUserRole.String()}, + }, + }, nil). + Once() + + w := httptest.NewRecorder() + + http.HandlerFunc(handler.GetAllUsers).ServeHTTP(w, req) + + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + wantsBody := ` + [ + { + "id": "user1-ID", + "first_name": "First", + "last_name": "Last", + "email": "user1@email.com", + "is_active": false, + "roles": [ + "business" + ] + }, + { + "id": "user2-ID", + "first_name": "First", + "last_name": "Last", + "email": "user2@email.com", + "is_active": true, + "roles": [ + "owner" + ] + } + ] + ` + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, wantsBody, string(respBody)) + }) +} diff --git a/internal/serve/httphandler/verifiy_receiver_registration_handler.go b/internal/serve/httphandler/verifiy_receiver_registration_handler.go new file mode 100644 index 000000000..5c01e91b9 --- /dev/null +++ b/internal/serve/httphandler/verifiy_receiver_registration_handler.go @@ -0,0 +1,226 @@ +package httphandler + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +// ErrorInformationNotFound implements the error interface. +type ErrorInformationNotFound struct { + cause error +} + +func (e *ErrorInformationNotFound) Error() string { + return e.cause.Error() +} + +const ( + InformationNotFoundOnServer = "the information you provided could not be found in our server" +) + +type VerifyReceiverRegistrationHandler struct { + AnchorPlatformAPIService anchorplatform.AnchorPlatformAPIServiceInterface + Models *data.Models + ReCAPTCHAValidator validators.ReCAPTCHAValidator + NetworkPassphrase string +} + +// VerifyReceiverRegistration implements the http.Handler interface. +func (v VerifyReceiverRegistrationHandler) VerifyReceiverRegistration(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // claims sep24 Token + sep24Claims := anchorplatform.GetSEP24Claims(ctx) + if sep24Claims == nil { + err := fmt.Errorf("no SEP-24 claims found in the request context") + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(w) + return + } + + // decode request payload into ReceiverRegistrationRequest + receiverRegistrationRequest := data.ReceiverRegistrationRequest{} + err := json.NewDecoder(r.Body).Decode(&receiverRegistrationRequest) + if err != nil { + log.Ctx(ctx).Errorf("invalid request body: %s", err.Error()) + httperror.BadRequest("invalid request body", err, nil).Render(w) + return + } + + // validating reCAPTCHA Token + isValid, err := v.ReCAPTCHAValidator.IsTokenValid(ctx, receiverRegistrationRequest.ReCAPTCHAToken) + if err != nil { + httperror.InternalError(ctx, "Cannot validate reCAPTCHA token", err, nil).Render(w) + return + } + + if !isValid { + log.Ctx(ctx).Errorf("reCAPTCHA token is invalid for request with OTP %s and Phone Number %s", + utils.TruncateString(receiverRegistrationRequest.OTP, 2), utils.TruncateString(receiverRegistrationRequest.PhoneNumber, 4)) + httperror.BadRequest("request invalid", nil, nil).Render(w) + return + } + + // validate request payload + validator := validators.NewReceiverRegistrationValidator() + validator.ValidateReceiver(&receiverRegistrationRequest) + if validator.HasErrors() { + log.Ctx(ctx).Errorf("request invalid: %s", validator.Errors) + httperror.BadRequest("request invalid", nil, validator.Errors).Render(w) + return + } + + err = db.RunInTransaction(ctx, v.Models.DBConnectionPool, nil, func(dbTx db.DBTransaction) error { + // get receiver with the phone number present in the payload + receiver, innerErr := v.Models.Receiver.GetByPhoneNumbers(ctx, dbTx, []string{receiverRegistrationRequest.PhoneNumber}) + if innerErr != nil { + log.Ctx(ctx).Errorf("error retrieving receiver with phone number %s: %s", utils.TruncateString(receiverRegistrationRequest.PhoneNumber, 3), innerErr.Error()) + return innerErr + } + if len(receiver) == 0 { + innerErr = fmt.Errorf("receiver with phone number %s not found in our server", receiverRegistrationRequest.PhoneNumber) + log.Ctx(ctx).Error(innerErr) + return &ErrorInformationNotFound{cause: innerErr} + } + + // get receiverVerification using receiver ID and the verification type + receiverVerifications, innerErr := v.Models.ReceiverVerification.GetByReceiverIdsAndVerificationField(ctx, dbTx, []string{receiver[0].ID}, receiverRegistrationRequest.VerificationType) + if innerErr != nil { + log.Ctx(ctx).Errorf("error retrieving receiver verification for verification type %s", receiverRegistrationRequest.VerificationType) + return innerErr + } + if len(receiverVerifications) == 0 { + innerErr = fmt.Errorf("%s not found for receiver with phone number %s", receiverRegistrationRequest.VerificationType, receiverRegistrationRequest.PhoneNumber) + log.Ctx(ctx).Error(innerErr) + return &ErrorInformationNotFound{cause: innerErr} + } + + if len(receiverVerifications) > 1 { + log.Ctx(ctx).Warnf("receiver with id %s has more than one verification saved in the database for type %s", receiver[0].ID, receiverRegistrationRequest.VerificationType) + } + + receiverVerification := receiverVerifications[0] + + if v.Models.ReceiverVerification.ExceededAttempts(receiverVerification.Attempts) { + innerErr = fmt.Errorf("number of attempts to confirm the verification value exceeded max attempts value %d", data.MaxAttemptsAllowed) + log.Ctx(ctx).Error(innerErr) + return &ErrorInformationNotFound{cause: innerErr} + } + + now := time.Now() + // check if verification value match with value saved in the database + if !data.CompareVerificationValue(receiverVerification.HashedValue, receiverRegistrationRequest.VerificationValue) { + baseErr := fmt.Sprintf("%s value does not match for user with phone number %s", receiverRegistrationRequest.VerificationType, receiverRegistrationRequest.PhoneNumber) + // update the receiver verification with the confirmation that the value was checked + receiverVerification.Attempts = receiverVerification.Attempts + 1 + receiverVerification.FailedAt = &now + receiverVerification.ConfirmedAt = nil + + // this update is done using the DBConnectionPool and not dbTx because we don't want to roolback these changes after returning the error + updateErr := v.Models.ReceiverVerification.UpdateReceiverVerification(ctx, *receiverVerification, v.Models.DBConnectionPool) + if updateErr != nil { + innerErr = fmt.Errorf("%s: %w", baseErr, updateErr) + } else { + innerErr = fmt.Errorf("%s", baseErr) + } + + log.Ctx(ctx).Error(innerErr) + return &ErrorInformationNotFound{cause: innerErr} + } + + // update the receiver verification with the confirmation that the value was checked + if receiverVerification.ConfirmedAt == nil { + receiverVerification.ConfirmedAt = &now + innerErr = v.Models.ReceiverVerification.UpdateReceiverVerification(ctx, *receiverVerification, dbTx) + if innerErr != nil { + log.Ctx(ctx).Error(innerErr) + return &ErrorInformationNotFound{cause: innerErr} + } + } + + receiverWallet, innerErr := v.Models.ReceiverWallet.GetByReceiverIDAndWalletDomain(ctx, receiver[0].ID, sep24Claims.ClientDomain(), dbTx) + if innerErr != nil { + log.Ctx(ctx).Errorf("receiver wallet not found for receiver with id %s and client domain %s", receiver[0].ID, sep24Claims.ClientDomain()) + return &ErrorInformationNotFound{cause: innerErr} + } + + // check if receiver is already registered + if receiverWallet.Status == data.RegisteredReceiversWalletStatus { + log.Ctx(ctx).Info("receiver already registered in the SDP") + return nil + } + + // check if receiver wallet status can transition to RegisteredReceiversWalletStatus + innerErr = receiverWallet.Status.TransitionTo(data.RegisteredReceiversWalletStatus) + if innerErr != nil { + log.Ctx(ctx).Errorf("receiver wallet for receiver with id %s has an invalid status %s, can not transaction to REGISTERED", receiver[0].ID, receiverWallet.Status) + return &ErrorInformationNotFound{cause: innerErr} + } + + // check if receiver_wallet OTP is valid and not expired + innerErr = v.Models.ReceiverWallet.VerifyReceiverWalletOTP(ctx, v.NetworkPassphrase, *receiverWallet, receiverRegistrationRequest.OTP) + if innerErr != nil { + log.Ctx(ctx).Errorf("receiver wallet otp is not valid: %s", innerErr.Error()) + return &ErrorInformationNotFound{cause: innerErr} + } + + // update transaction on AnchorPlatform using AnchorPlatformAPIService + transaction := &anchorplatform.Transaction{ + TransactionValues: anchorplatform.TransactionValues{ + ID: sep24Claims.TransactionID(), + Status: "pending_anchor", + Sep: "24", + Kind: "deposit", + DestinationAccount: sep24Claims.SEP10StellarAccount(), + Memo: sep24Claims.SEP10StellarMemo(), + KYCVerified: true, + }, + } + + // update receiver wallet + receiverWallet.StellarAddress = sep24Claims.SEP10StellarAccount() + if sep24Claims.SEP10StellarMemo() != "" { + receiverWallet.StellarMemo = sep24Claims.SEP10StellarMemo() + receiverWallet.StellarMemoType = "id" + } + receiverWallet.Status = data.RegisteredReceiversWalletStatus + + innerErr = v.Models.ReceiverWallet.UpdateReceiverWallet(ctx, *receiverWallet, dbTx) + if innerErr != nil { + log.Ctx(ctx).Errorf("error updating receiver wallet status to registered for receiver with phone number %s", utils.TruncateString(receiverRegistrationRequest.PhoneNumber, 3)) + return innerErr + } + + innerErr = v.AnchorPlatformAPIService.UpdateAnchorTransactions(ctx, []anchorplatform.Transaction{*transaction}) + if innerErr != nil { + innerErr = fmt.Errorf("error updating transaction with ID %s on anchor platform API: %w", sep24Claims.TransactionID(), innerErr) + return innerErr + } + + return nil + }) + + if err != nil { + var errorInformationNotFound *ErrorInformationNotFound + if errors.As(err, &errorInformationNotFound) { + httperror.BadRequest(InformationNotFoundOnServer, err, nil).Render(w) + return + } + httperror.InternalError(ctx, "", err, nil).Render(w) + return + } + + httpjson.RenderStatus(w, http.StatusOK, map[string]string{"message": "ok"}, httpjson.JSON) +} diff --git a/internal/serve/httphandler/verifiy_receiver_registration_handler_test.go b/internal/serve/httphandler/verifiy_receiver_registration_handler_test.go new file mode 100644 index 000000000..6bc7d4f18 --- /dev/null +++ b/internal/serve/httphandler/verifiy_receiver_registration_handler_test.go @@ -0,0 +1,1283 @@ +package httphandler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/golang-jwt/jwt/v4" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_VerifyReceiverRegistration(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + mockAnchorPlatformService := anchorplatform.AnchorPlatformAPIServiceMock{} + reCAPTCHAValidator := &validators.ReCAPTCHAValidatorMock{} + handler := &VerifyReceiverRegistrationHandler{ + Models: models, + AnchorPlatformAPIService: &mockAnchorPlatformService, + ReCAPTCHAValidator: reCAPTCHAValidator, + } + + // setup + r := chi.NewRouter() + r.Post("/wallet-registration/verification", handler.VerifyReceiverRegistration) + + t.Run("error unauthorized sep24 token not found", func(t *testing.T) { + req, err := http.NewRequest("POST", "/wallet-registration/verification", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + ctx := context.Background() + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "testWallet", "https://home.page", "home.page", "wallet123://") + + t.Run("error internal server error when the reCAPTCHA validator fails", func(t *testing.T) { + reqBody := ` + { + "phone_number": "+380445555555", + "otp": "123456", + "verification_value": "1990-01-01", + "verification_type": "date_of_birth", + "reCAPTCHA_token": "token" + } + ` + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "/wallet-registration/verification", strings.NewReader(reqBody)) + require.NoError(t, err) + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(false, errors.New("unexpected error")). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error": "Cannot validate reCAPTCHA token"}`, string(respBody)) + + entries := getEntries() + assert.NotEmpty(t, entries) + assert.Equal(t, "Cannot validate reCAPTCHA token: unexpected error", entries[0].Message) + }) + + t.Run("error bad request when the reCAPTCHA token is invalid", func(t *testing.T) { + reqBody := ` + { + "phone_number": "+380445555555", + "otp": "123456", + "verification_value": "1990-01-01", + "verification_type": "date_of_birth", + "reCAPTCHA_token": "token" + } + ` + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "/wallet-registration/verification", strings.NewReader(reqBody)) + require.NoError(t, err) + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(false, nil). + Once() + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, `{"error": "request invalid"}`, string(respBody)) + + entries := getEntries() + assert.NotEmpty(t, entries) + assert.Equal(t, "reCAPTCHA token is invalid for request with OTP 12...56 and Phone Number +380...5555", entries[0].Message) + }) + + t.Run("error invalid request body", func(t *testing.T) { + testCases := []struct { + name string + request data.ReceiverRegistrationRequest + want string + }{ + { + name: "empty phone number", + request: data.ReceiverRegistrationRequest{ + PhoneNumber: "", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + }, + want: ` + { + "error": "request invalid", + "extras": { + "phone_number": "phone cannot be empty" + } + } + `, + }, + { + name: "invalid phone number", + request: data.ReceiverRegistrationRequest{ + PhoneNumber: "invalid_phone", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + }, + want: ` + { + "error": "request invalid", + "extras": { + "phone_number": "invalid phone format. Correct format: +380445555555" + } + } + `, + }, + { + name: "invalid otp", + request: data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "12mock", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + }, + want: ` + { + "error": "request invalid", + "extras": { + "otp": "invalid otp format. Needs to be a 6 digit value" + } + } + `, + }, + { + name: "invalid verification type", + request: data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "invalid", + ReCAPTCHAToken: "token", + }, + want: ` + { + "error": "request invalid", + "extras": { + "verification_type": "invalid parameter. valid values are: DATE_OF_BIRTH, PIN, NATIONAL_ID_NUMBER" + } + } + `, + }, + { + name: "invalid verification value", + request: data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "90/01/01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + }, + want: ` + { + "error": "request invalid", + "extras": { + "verification": "invalid date of birth format. Correct format: 1990-01-01" + } + } + `, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reqBody, err := json.Marshal(tc.request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + rr := httptest.NewRecorder() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, tc.want, string(respBody)) + }) + } + }) + + t.Run("error receiver not found in our server", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{ + "error": "%s" + }`, InformationNotFoundOnServer) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "receiver with phone number +380445555555 not found in our server") + }) + + t.Run("error receiver verification not found in our server", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + _ = data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{ + "error": "%s" + }`, InformationNotFoundOnServer) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "DATE_OF_BIRTH not found for receiver with phone number +380445555555") + }) + + t.Run("error comparing verification values exceeded attempts", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + receiverVerification := data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "2000-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + + reqBody, err := json.Marshal(request) + require.NoError(t, err) + + attempts := 0 + + const totalAttempts = data.MaxAttemptsAllowed + 1 + for range [totalAttempts]interface{}{} { + buf.Reset() + + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + r.ServeHTTP(rr, req) + + attempts += 1 + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{ + "error": "%s" + }`, InformationNotFoundOnServer) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate the number of attempts + query := ` + SELECT + rv.attempts + FROM + receiver_verifications rv + WHERE + rv.receiver_id = $1 AND rv.verification_field = $2 + ` + receiverVerificationUpdated := data.ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &receiverVerificationUpdated, query, receiverVerification.ReceiverID, receiverVerification.VerificationField) + require.NoError(t, err) + + expectedLog := "" + if attempts == totalAttempts { + expectedLog = "number of attempts to confirm the verification value exceeded max attempts value 6" + assert.Equal(t, data.MaxAttemptsAllowed, receiverVerificationUpdated.Attempts) + } else { + expectedLog = "DATE_OF_BIRTH value does not match for user with phone number +380445555555" + assert.Equal(t, attempts, receiverVerificationUpdated.Attempts) + } + + // validate logs + require.Contains(t, buf.String(), expectedLog) + } + }) + + t.Run("error comparing verification values", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + receiverVerification := data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "2000-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{ + "error": "%s" + }`, InformationNotFoundOnServer) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate if the receiver verification has been updated + query := ` + SELECT + rv.attempts, + rv.confirmed_at, + rv.failed_at + FROM + receiver_verifications rv + WHERE + rv.receiver_id = $1 AND rv.verification_field = $2 + ` + receiverVerificationUpdated := data.ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &receiverVerificationUpdated, query, receiverVerification.ReceiverID, receiverVerification.VerificationField) + require.NoError(t, err) + + assert.Empty(t, receiverVerificationUpdated.ConfirmedAt) + assert.NotEmpty(t, receiverVerificationUpdated.FailedAt) + assert.Equal(t, 1, receiverVerificationUpdated.Attempts) + + // validate logs + require.Contains(t, buf.String(), "DATE_OF_BIRTH value does not match for user with phone number +380445555555") + }) + + t.Run("error getting receiver wallet", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{ + "error": "%s" + }`, InformationNotFoundOnServer) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate logs + msg := fmt.Sprintf("receiver wallet not found for receiver with id %s and client domain home.page", receiver.ID) + require.Contains(t, buf.String(), msg) + }) + + t.Run("error receiver wallet otp does not match the value saved in the database", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "111111", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{ + "error": "%s" + }`, InformationNotFoundOnServer) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "receiver wallet otp is not valid: otp does not match with value saved in the database") + }) + + t.Run("error receiver wallet otp is expired", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + + query := ` + UPDATE + receiver_wallets rw + SET + otp_created_at = $1 + WHERE + rw.stellar_address = $2 + ` + expiredOTPCreatedAt := time.Now().Add(-data.OTPExpirationTimeMinutes * time.Minute).Add(-time.Second) // expired 1 second ago + _, err = dbConnectionPool.ExecContext(ctx, query, expiredOTPCreatedAt, receiverWallet.StellarAddress) + require.NoError(t, err) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{"error": "%s"}`, InformationNotFoundOnServer) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "receiver wallet otp is not valid: otp is expired") + }) + + t.Run("error anchor platform service API", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + + // set stellar values to empty + query := ` + UPDATE + receiver_wallets rw + SET + stellar_address = '', + stellar_memo = '', + stellar_memo_type = '' + WHERE + rw.id = $1 + ` + _, err = dbConnectionPool.ExecContext(ctx, query, receiverWallet.ID) + require.NoError(t, err) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + transaction := &anchorplatform.Transaction{ + TransactionValues: anchorplatform.TransactionValues{ + ID: "test-transaction-id", + Status: "pending_anchor", + Sep: "24", + Kind: "deposit", + DestinationAccount: validClaims.SEP10StellarAccount(), + Memo: validClaims.SEP10StellarMemo(), + KYCVerified: true, + }, + } + mockAnchorPlatformService. + On("UpdateAnchorTransactions", mock.Anything, []anchorplatform.Transaction{*transaction}). + Return(fmt.Errorf("error updating transaction on anchor platform")).Once() + + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := `{ + "error": "An internal error occurred while processing this request." + } + ` + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate if the receiver wallet has been updated + query = ` + SELECT + rw.status, + rw.stellar_address, + rw.stellar_memo, + rw.stellar_memo_type, + otp_confirmed_at + FROM + receiver_wallets rw + WHERE + rw.id = $1 + ` + receiverWalletUpdated := data.ReceiverWallet{} + err = dbConnectionPool.GetContext(ctx, &receiverWalletUpdated, query, receiverWallet.ID) + require.NoError(t, err) + + assert.Equal(t, data.ReadyReceiversWalletStatus, receiverWalletUpdated.Status) + assert.Empty(t, receiverWalletUpdated.StellarAddress) + assert.Empty(t, receiverWalletUpdated.StellarMemo) + assert.Empty(t, receiverWalletUpdated.StellarMemoType) + require.Empty(t, receiverWalletUpdated.OTPConfirmedAt) + + // validate logs + require.Contains(t, buf.String(), "error updating transaction with ID test-transaction-id on anchor platform API") + mockAnchorPlatformService.AssertExpectations(t) + }) + + t.Run("receiver already registered", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.InfoLevel) + + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := `{ + "message": "ok" + } + ` + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate logs + require.Contains(t, buf.String(), "receiver already registered in the SDP") + }) + + t.Run("invalid receiver wallet status", func(t *testing.T) { + // set the logger to a buffer so we can check the error message + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := fmt.Sprintf(`{ + "error": "%s" + }`, InformationNotFoundOnServer) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate logs + msg := fmt.Sprintf("receiver wallet for receiver with id %s has an invalid status DRAFT, can not transaction to REGISTERED", receiver.ID) + require.Contains(t, buf.String(), msg) + }) + + t.Run("successfully verifying receiver registration without stellar memo", func(t *testing.T) { + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + transaction := &anchorplatform.Transaction{ + TransactionValues: anchorplatform.TransactionValues{ + ID: "test-transaction-id", + Status: "pending_anchor", + Sep: "24", + Kind: "deposit", + DestinationAccount: validClaims.SEP10StellarAccount(), + Memo: validClaims.SEP10StellarMemo(), + KYCVerified: true, + }, + } + mockAnchorPlatformService. + On("UpdateAnchorTransactions", mock.Anything, []anchorplatform.Transaction{*transaction}). + Return(nil).Once() + + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := `{ + "message": "ok" + } + ` + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate if the receiver wallet has been updated + query := ` + SELECT + rw.status, + rw.stellar_address, + COALESCE(rw.stellar_memo, '') as "stellar_memo", + COALESCE(rw.stellar_memo_type, '') as "stellar_memo_type", + otp_confirmed_at + FROM + receiver_wallets rw + WHERE + rw.id = $1 + ` + receiverWalletUpdated := data.ReceiverWallet{} + err = dbConnectionPool.GetContext(ctx, &receiverWalletUpdated, query, receiverWallet.ID) + require.NoError(t, err) + + assert.Equal(t, data.RegisteredReceiversWalletStatus, receiverWalletUpdated.Status) + assert.Equal(t, "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", receiverWalletUpdated.StellarAddress) + assert.Empty(t, receiverWalletUpdated.StellarMemo) + assert.Empty(t, receiverWalletUpdated.StellarMemoType) + require.NotEmpty(t, receiverWalletUpdated.OTPConfirmedAt) + + mockAnchorPlatformService.AssertExpectations(t) + }) + + t.Run("successfully verifying receiver registration with stellar memo", func(t *testing.T) { + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverVerificationFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{ + PhoneNumber: "+380445555555", + }) + + receiverVerification := data.CreateReceiverVerificationFixture(t, ctx, dbConnectionPool, data.ReceiverVerificationInsert{ + ReceiverID: receiver.ID, + VerificationField: data.VerificationFieldDateOfBirth, + VerificationValue: "1990-01-01", + }) + + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.ReadyReceiversWalletStatus) + _, err := models.ReceiverWallet.UpdateOTPByReceiverPhoneNumberAndWalletDomain(ctx, "+380445555555", wallet.SEP10ClientDomain, "123456") + require.NoError(t, err) + + request := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "date_of_birth", + ReCAPTCHAToken: "token", + } + + reqBody, err := json.Marshal(request) + require.NoError(t, err) + req, err := http.NewRequest("POST", "/wallet-registration/verification", strings.NewReader(string(reqBody))) + require.NoError(t, err) + + rr := httptest.NewRecorder() + + reCAPTCHAValidator. + On("IsTokenValid", mock.Anything, "token"). + Return(true, nil). + Once() + + // create valid sep24 token + validClaims := &anchorplatform.SEP24JWTClaims{ + ClientDomainClaim: wallet.SEP10ClientDomain, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "test-transaction-id", + Subject: "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444:123456", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + } + req = req.WithContext(context.WithValue(req.Context(), anchorplatform.SEP24ClaimsContextKey, validClaims)) + + transaction := &anchorplatform.Transaction{ + TransactionValues: anchorplatform.TransactionValues{ + ID: "test-transaction-id", + Status: "pending_anchor", + Sep: "24", + Kind: "deposit", + DestinationAccount: validClaims.SEP10StellarAccount(), + Memo: validClaims.SEP10StellarMemo(), + KYCVerified: true, + }, + } + mockAnchorPlatformService. + On("UpdateAnchorTransactions", mock.Anything, []anchorplatform.Transaction{*transaction}). + Return(nil).Once() + + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + want := `{ + "message": "ok" + } + ` + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, want, string(respBody)) + + // validate if the receiver wallet has been updated + query := ` + SELECT + rw.status, + rw.stellar_address, + rw.stellar_memo, + rw.stellar_memo_type, + otp_confirmed_at + FROM + receiver_wallets rw + WHERE + rw.id = $1 + ` + receiverWalletUpdated := data.ReceiverWallet{} + err = dbConnectionPool.GetContext(ctx, &receiverWalletUpdated, query, receiverWallet.ID) + require.NoError(t, err) + + assert.Equal(t, data.RegisteredReceiversWalletStatus, receiverWalletUpdated.Status) + assert.Equal(t, "GBLTXF46JTCGMWFJASQLVXMMA36IPYTDCN4EN73HRXCGDCGYBZM3A444", receiverWalletUpdated.StellarAddress) + assert.Equal(t, "123456", receiverWalletUpdated.StellarMemo) + assert.Equal(t, "id", receiverWalletUpdated.StellarMemoType) + require.NotEmpty(t, receiverWalletUpdated.OTPConfirmedAt) + + // validate if the receiver verification field confirmed_at has been updated + query = ` + SELECT + rv.confirmed_at + FROM + receiver_verifications rv + WHERE + rv.receiver_id = $1 AND rv.verification_field = $2 + ` + receiverVerificationUpdated := data.ReceiverVerification{} + err = dbConnectionPool.GetContext(ctx, &receiverVerificationUpdated, query, receiverVerification.ReceiverID, receiverVerification.VerificationField) + require.NoError(t, err) + + assert.NotEmpty(t, receiverVerificationUpdated.ConfirmedAt) + + mockAnchorPlatformService.AssertExpectations(t) + }) + + reCAPTCHAValidator.AssertExpectations(t) +} diff --git a/internal/serve/httphandler/wallets_handler.go b/internal/serve/httphandler/wallets_handler.go new file mode 100644 index 000000000..8d37361f0 --- /dev/null +++ b/internal/serve/httphandler/wallets_handler.go @@ -0,0 +1,23 @@ +package httphandler + +import ( + "net/http" + + "github.com/stellar/go/support/render/httpjson" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" +) + +type WalletsHandler struct { + Models *data.Models +} + +// GetWallets returns a list of wallets +func (c WalletsHandler) GetWallets(w http.ResponseWriter, r *http.Request) { + countries, err := c.Models.Wallets.GetAll(r.Context()) + if err != nil { + httperror.InternalError(r.Context(), "Cannot retrieve list of wallets", err, nil).Render(w) + return + } + httpjson.Render(w, countries, httpjson.JSON) +} diff --git a/internal/serve/httphandler/wallets_handler_test.go b/internal/serve/httphandler/wallets_handler_test.go new file mode 100644 index 000000000..30aba5b8e --- /dev/null +++ b/internal/serve/httphandler/wallets_handler_test.go @@ -0,0 +1,53 @@ +package httphandler + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_WalletsHandlerGetWallets(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + handler := &WalletsHandler{ + Models: models, + } + + t.Run("successfully returns a list of countries", func(t *testing.T) { + expected := data.ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool) + expectedJSON, err := json.Marshal(expected) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/wallets", nil) + http.HandlerFunc(handler.GetWallets).ServeHTTP(rr, req) + + resp := rr.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + assert.JSONEq(t, string(expectedJSON), string(respBody)) + }) +} diff --git a/internal/serve/httpresponse/paginated_response.go b/internal/serve/httpresponse/paginated_response.go new file mode 100644 index 000000000..559ea84e3 --- /dev/null +++ b/internal/serve/httpresponse/paginated_response.go @@ -0,0 +1,65 @@ +package httpresponse + +import ( + "encoding/json" + "fmt" + "net/http" +) + +// PaginatedResponse is a response that contains pagination information. +type PaginatedResponse struct { + Pagination PaginationInfo `json:"pagination"` + Data json.RawMessage `json:"data"` +} + +type PaginationInfo struct { + Next string `json:"next,omitempty"` + Prev string `json:"prev,omitempty"` + Pages int `json:"pages"` + Total int `json:"total"` +} + +// NewEmptyPaginatedResponse returns a PaginatedResponse with an empty data and 0 pages. +// +// This is useful for returning an empty list. +func NewEmptyPaginatedResponse() PaginatedResponse { + return PaginatedResponse{ + Pagination: PaginationInfo{ + Pages: 0, + Total: 0, + }, + Data: json.RawMessage("[]"), + } +} + +// NewPaginatedResponse returns a PaginatedResponse with pagination information. +func NewPaginatedResponse(r *http.Request, data interface{}, currentPage, pageLimit, totalItems int) (PaginatedResponse, error) { + totalPages := (totalItems + pageLimit - 1) / pageLimit + pagination := PaginationInfo{Pages: totalPages, Total: totalItems} + + baseURL := *r.URL + q := baseURL.Query() + q.Del("page") + + if currentPage < totalPages { + q.Set("page", fmt.Sprintf("%d", currentPage+1)) + baseURL.RawQuery = q.Encode() + pagination.Next = baseURL.String() + } + + if currentPage > 1 { + q.Set("page", fmt.Sprintf("%d", currentPage-1)) + baseURL.RawQuery = q.Encode() + pagination.Prev = baseURL.String() + } + + dataBytes, err := json.Marshal(data) + if err != nil { + return PaginatedResponse{}, err + } + + return PaginatedResponse{ + Pagination: pagination, + Data: dataBytes, + }, nil +} diff --git a/internal/serve/middleware/middleware.go b/internal/serve/middleware/middleware.go new file mode 100644 index 000000000..54f34fe5f --- /dev/null +++ b/internal/serve/middleware/middleware.go @@ -0,0 +1,203 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5/middleware" + "github.com/rs/cors" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +type ContextKey string + +const TokenContextKey ContextKey = "auth_token" + +// RecoverHandler is a middleware that recovers from panics and logs the error. +func RecoverHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + defer func() { + r := recover() + if r == nil { + return + } + err, ok := r.(error) + if !ok { + err = fmt.Errorf("panic: %v", r) + } + + // No need to recover when the client has disconnected: + if errors.Is(err, http.ErrAbortHandler) { + panic(err) + } + + ctx := req.Context() + log.Ctx(ctx).WithStack(err).Error(err) + httperror.InternalError(ctx, "", err, nil).Render(rw) + }() + + next.ServeHTTP(rw, req) + }) +} + +// MetricsRequestHandler is a middleware that monitors http requests, and export the data +// to the metrics server +func MetricsRequestHandler(monitorService monitor.MonitorServiceInterface) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + mw := middleware.NewWrapResponseWriter(rw, req.ProtoMajor) + then := time.Now() + next.ServeHTTP(mw, req) + + duration := time.Since(then) + + labels := monitor.HttpRequestLabels{ + Status: fmt.Sprintf("%d", mw.Status()), + Route: utils.GetRoutePattern(req), + Method: req.Method, + } + + err := monitorService.MonitorHttpRequestDuration(duration, labels) + if err != nil { + log.Errorf("Error trying to monitor request time: %s", err) + } + }) + } +} + +// AuthenticateMiddleware is a middleware that validates the Authorization header for +// authenticated endpoints. +func AuthenticateMiddleware(authManager auth.AuthManager) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + authHeader := req.Header.Get("Authorization") + if authHeader == "" { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 { + httperror.Unauthorized("Invalid Authorization header provided.", nil, nil).Render(rw) + return + } + + ctx := req.Context() + token := authHeaderParts[1] + isValid, err := authManager.ValidateToken(ctx, token) + if err != nil { + err = fmt.Errorf("error validating auth token: %w", err) + log.Ctx(ctx).Error(err) + httperror.Unauthorized("", err, nil).Render(rw) + return + } + + if !isValid { + httperror.Unauthorized("Invalid token provided.", nil, nil).Render(rw) + return + } + + // Add the token to the request context + ctx = context.WithValue(ctx, TokenContextKey, token) + req = req.WithContext(ctx) + + next.ServeHTTP(rw, req) + }) + } +} + +// AnyRoleMiddleware validates if the user has at least one of the required roles to request +// the current endpoint. +func AnyRoleMiddleware(authManager auth.AuthManager, requiredRoles ...data.UserRole) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + + token, ok := ctx.Value(TokenContextKey).(string) + if !ok { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + // Accessible by all users + if len(requiredRoles) == 0 { + next.ServeHTTP(rw, req) + return + } + + isValid, err := authManager.AnyRolesInTokenUser(ctx, token, data.FromUserRoleArrayToStringArray(requiredRoles)) + if err != nil && !errors.Is(err, auth.ErrInvalidToken) { + httperror.InternalError(ctx, "", err, nil).Render(rw) + return + } + + if !isValid { + httperror.Unauthorized("", nil, nil).Render(rw) + return + } + + next.ServeHTTP(rw, req) + }) + } +} + +func CorsMiddleware(corsAllowedOrigins []string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + cors := cors.New(cors.Options{ + AllowedOrigins: corsAllowedOrigins, + AllowedHeaders: []string{"*"}, + AllowedMethods: []string{"GET", "PUT", "POST", "PATCH", "DELETE", "HEAD", "OPTIONS"}, + }) + + return cors.Handler(next) + } +} + +type cspItem struct { + ContentType string + Policy []string +} + +func (c cspItem) String() string { + return fmt.Sprintf("%s %s;", c.ContentType, strings.Join(c.Policy, " ")) +} + +// CSPMiddleware is the middleware that sets the content security policy, restricting content to only be accessed +// from specified sources in the header. +func CSPMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + selfSrc := "'self'" + recaptchaSrc := "https://www.google.com/recaptcha/" + cspItems := []cspItem{ + {"script-src", []string{selfSrc, recaptchaSrc, "https://www.gstatic.com/recaptcha/"}}, + {"style-src", []string{selfSrc, recaptchaSrc, "https://fonts.googleapis.com/css2", "'unsafe-inline'"}}, + {"connect-src", []string{selfSrc, recaptchaSrc}}, + {"font-src", []string{selfSrc, "https://fonts.gstatic.com"}}, + {"default-src", []string{selfSrc}}, + + {"frame-src", []string{selfSrc, recaptchaSrc}}, + {"frame-ancestors", []string{selfSrc}}, + + {"form-action", []string{selfSrc}}, + } + cspStr := "" + for _, item := range cspItems { + cspStr += fmt.Sprintf("%v", item) + } + + // policyStr := "default-src 'self'; script-src 'self'; frame-ancestors 'self'; form-action 'self';" + rw.Header().Set("Content-Security-Policy", cspStr) + next.ServeHTTP(rw, req) + }) + } +} diff --git a/internal/serve/middleware/middleware_test.go b/internal/serve/middleware/middleware_test.go new file mode 100644 index 000000000..1915edd9f --- /dev/null +++ b/internal/serve/middleware/middleware_test.go @@ -0,0 +1,617 @@ +package middleware + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/sirupsen/logrus" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_RecoverHandler(t *testing.T) { + // setup logger to assert the logged texts later + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(logrus.TraceLevel) + + // setup + r := chi.NewRouter() + r.Use(RecoverHandler) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + // test + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusInternalServerError, rr.Code) + wantJson := `{ + "error": "An internal error occurred while processing this request." + }` + assert.JSONEq(t, wantJson, rr.Body.String()) + + // assert logged text + assert.Contains(t, buf.String(), "panic: test panic", "should log the panic message") +} + +func Test_RecoverHandler_doesNotRecoverFromErrAbortHandler(t *testing.T) { + // setup logger to assert the logged texts later + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(logrus.TraceLevel) + + // setup + r := chi.NewRouter() + r.Use(RecoverHandler) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + panic(http.ErrAbortHandler) + }) + + // test + require.Panics(t, func() { + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + }, "http.ErrAbortHandler is supposed to panic") +} + +func Test_MetricsRequestHandler(t *testing.T) { + mMonitorService := &monitor.MockMonitorService{} + + // setup + r := chi.NewRouter() + r.Use(MetricsRequestHandler(mMonitorService)) + r.Get("/mock", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"status": "OK"}`)) + require.NoError(t, err) + }) + + t.Run("monitor request with valid route", func(t *testing.T) { + mLabels := monitor.HttpRequestLabels{ + Status: "200", + Route: "/mock", + Method: "GET", + } + + mMonitorService.On("MonitorHttpRequestDuration", mock.AnythingOfType("time.Duration"), mLabels).Return(nil).Once() + + // test + req, err := http.NewRequest("GET", "/mock", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusOK, rr.Code) + wantBody := `{"status": "OK"}` + assert.JSONEq(t, wantBody, rr.Body.String()) + + mMonitorService.AssertExpectations(t) + }) + + t.Run("monitor request with invalid route", func(t *testing.T) { + mLabels := monitor.HttpRequestLabels{ + Status: "404", + Route: "undefined", + Method: "GET", + } + + mMonitorService.On("MonitorHttpRequestDuration", mock.AnythingOfType("time.Duration"), mLabels).Return(nil).Once() + + // test + req, err := http.NewRequest("GET", "/invalid-route", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusNotFound, rr.Code) + + mMonitorService.AssertExpectations(t) + }) + + t.Run("monitor request with method not allowed", func(t *testing.T) { + mLabels := monitor.HttpRequestLabels{ + Status: "405", + Route: "undefined", + Method: "POST", + } + + mMonitorService.On("MonitorHttpRequestDuration", mock.AnythingOfType("time.Duration"), mLabels).Return(nil).Once() + + // test + req, err := http.NewRequest("POST", "/mock", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + // assert response + assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) + + mMonitorService.AssertExpectations(t) + }) +} + +func Test_AuthenticateMiddleware(t *testing.T) { + r := chi.NewRouter() + + jwtManagerMock := &auth.JWTManagerMock{} + authManager := auth.NewAuthManager(auth.WithCustomJWTManagerOption(jwtManagerMock)) + + r.Group(func(r chi.Router) { + r.Use(AuthenticateMiddleware(authManager)) + + r.Get("/authenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + }) + + r.Get("/unauthenticated", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + + t.Run("returns Unauthorized error when no header is sent", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns Unauthorized error when a invalid header is sent", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + + // Only one part + req.Header.Set("Authorization", "BearerToken") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Invalid Authorization header provided."}`, string(respBody)) + + req, err = http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + + // More than two parts + req.Header.Set("Authorization", "Bearer token token") + + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp = w.Result() + respBody, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Invalid Authorization header provided."}`, string(respBody)) + }) + + t.Run("returns Unauthorized when a unexpected error occurs validating the token", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + + req.Header.Set("Authorization", "Bearer token") + + jwtManagerMock. + On("ValidateToken", mock.Anything, "token"). + Return(false, errors.New("unexpected error")). + Once() + + getEntries := log.DefaultLogger.StartTest(log.ErrorLevel) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error": "Not authorized."}`, string(respBody)) + + entries := getEntries() + assert.NotEmpty(t, entries) + assert.Equal(t, `error validating auth token: validating token: unexpected error`, entries[0].Message) + }) + + t.Run("returns Unauthorized when the token is invalid", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + + req.Header.Set("Authorization", "Bearer token") + + jwtManagerMock. + On("ValidateToken", mock.Anything, "token"). + Return(false, nil). + Once() + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Invalid token provided."}`, string(respBody)) + }) + + t.Run("returns the response successfully", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/authenticated", nil) + require.NoError(t, err) + + req.Header.Set("Authorization", "Bearer token") + + jwtManagerMock. + On("ValidateToken", mock.Anything, "token"). + Return(true, nil). + Once() + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"status":"ok"}`, string(respBody)) + }) + + t.Run("doesn't return Unauthorized for unauthenticated routes", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/unauthenticated", nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"status":"ok"}`, string(respBody)) + }) +} + +func Test_AnyRoleMiddleware(t *testing.T) { + jwtManagerMock := &auth.JWTManagerMock{} + roleManagerMock := &auth.RoleManagerMock{} + authManager := auth.NewAuthManager( + auth.WithCustomJWTManagerOption(jwtManagerMock), + auth.WithCustomRoleManagerOption(roleManagerMock), + ) + + const url = "/restricted" + + setRestrictedEndpoint := func(ctx context.Context, r *chi.Mux, roles ...data.UserRole) { + r.With(AnyRoleMiddleware(authManager, roles...)). + Get(url, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(json.RawMessage(`{"status":"ok"}`)) + require.NoError(t, err) + }) + } + + t.Run("returns Unauthorized when no token is in the request context", func(t *testing.T) { + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r := chi.NewRouter() + setRestrictedEndpoint(ctx, r, "role1", "role2") + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns Unauthorized when no token is in the request context", func(t *testing.T) { + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r := chi.NewRouter() + setRestrictedEndpoint(ctx, r, "role1", "role2") + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns Unauthorized when the token is expired", func(t *testing.T) { + token := "mytoken" + ctx := context.WithValue(context.Background(), TokenContextKey, token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r := chi.NewRouter() + setRestrictedEndpoint(ctx, r, "role1", "role2") + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(false, nil). + Once() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns Internal Server error when an error occurs", func(t *testing.T) { + token := "mytoken" + ctx := context.WithValue(context.Background(), TokenContextKey, token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + r := chi.NewRouter() + setRestrictedEndpoint(ctx, r, "role1", "role2") + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(false, errors.New("unexpected error")). + Once() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.JSONEq(t, `{"error":"An internal error occurred while processing this request."}`, string(respBody)) + }) + + t.Run("returns Unauthorized error when the user does not have the required roles", func(t *testing.T) { + token := "mytoken" + ctx := context.WithValue(context.Background(), TokenContextKey, token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + requiredRoles := []data.UserRole{data.BusinessUserRole, data.FinancialControllerUserRole} + + r := chi.NewRouter() + setRestrictedEndpoint(ctx, r, requiredRoles...) + + user := &auth.User{ + ID: "user-id", + Email: "email@email.com", + Roles: []string{data.DeveloperUserRole.String()}, + } + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(true, nil). + Once(). + On("GetUserFromToken", mock.Anything, token). + Return(user, nil). + Once() + + roleManagerMock. + On("HasAnyRoles", mock.Anything, user, data.FromUserRoleArrayToStringArray(requiredRoles)). + Return(false, nil). + Once() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.JSONEq(t, `{"error":"Not authorized."}`, string(respBody)) + }) + + t.Run("returns Status Ok when user has the required roles", func(t *testing.T) { + token := "mytoken" + ctx := context.WithValue(context.Background(), TokenContextKey, token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + requiredRoles := []data.UserRole{data.BusinessUserRole, data.DeveloperUserRole} + + r := chi.NewRouter() + setRestrictedEndpoint(ctx, r, requiredRoles...) + + user := &auth.User{ + ID: "user-id", + Email: "email@email", + Roles: []string{data.DeveloperUserRole.String()}, + } + + jwtManagerMock. + On("ValidateToken", mock.Anything, token). + Return(true, nil). + Once(). + On("GetUserFromToken", mock.Anything, token). + Return(user, nil). + Once() + + roleManagerMock. + On("HasAnyRoles", mock.Anything, user, data.FromUserRoleArrayToStringArray(requiredRoles)). + Return(true, nil). + Once() + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"status":"ok"}`, string(respBody)) + }) + + t.Run("returns Status Ok when no roles is required", func(t *testing.T) { + token := "mytoken" + ctx := context.WithValue(context.Background(), TokenContextKey, token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + + requiredRoles := []data.UserRole{} + + r := chi.NewRouter() + setRestrictedEndpoint(ctx, r, requiredRoles...) + + r.ServeHTTP(w, req) + + resp := w.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.JSONEq(t, `{"status":"ok"}`, string(respBody)) + }) +} + +func Test_CorsMiddleware(t *testing.T) { + t.Run("Should work with an expected origin", func(t *testing.T) { + r := chi.NewRouter() + requestBaseURL := "http://myserver.com/*" + expectedRespBody := "ok" + + r.Use(CorsMiddleware([]string{requestBaseURL})) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(expectedRespBody)) + require.NoError(t, err) + }) + + expectedReqOrigin := "http://myserver.com/custompage" + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + req.Header.Add("Origin", expectedReqOrigin) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, expectedReqOrigin, resp.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, expectedRespBody, string(respBody)) + }) + + t.Run("Should not return Access-Control-Allow-Origin header with unexpected origin", func(t *testing.T) { + r := chi.NewRouter() + requestBaseURL := "http://myserver.com" + expectedRespBody := "ok" + + r.Use(CorsMiddleware([]string{requestBaseURL})) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(expectedRespBody)) + require.NoError(t, err) + }) + + reqOrigin := "http://locahost:8080" + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + req.Header.Add("Origin", reqOrigin) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Empty(t, resp.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, expectedRespBody, string(respBody)) + }) +} + +func Test_CSPMiddleware(t *testing.T) { + t.Run("Should populate the Content-Security-Policy header correctly", func(t *testing.T) { + r := chi.NewRouter() + expectedRespBody := "ok" + + r.Use(CSPMiddleware()) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(expectedRespBody)) + require.NoError(t, err) + }) + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + wantCSP := "script-src 'self' https://www.google.com/recaptcha/ https://www.gstatic.com/recaptcha/;style-src 'self' https://www.google.com/recaptcha/ https://fonts.googleapis.com/css2 'unsafe-inline';connect-src 'self' https://www.google.com/recaptcha/;font-src 'self' https://fonts.gstatic.com;default-src 'self';frame-src 'self' https://www.google.com/recaptcha/;frame-ancestors 'self';form-action 'self';" + gotCSP := resp.Header.Get("Content-Security-Policy") + assert.Equal(t, wantCSP, gotCSP) + assert.Equal(t, expectedRespBody, string(respBody)) + }) +} diff --git a/internal/serve/publicfiles/css/receiver_registration.css b/internal/serve/publicfiles/css/receiver_registration.css new file mode 100644 index 000000000..f5ae9bb3b --- /dev/null +++ b/internal/serve/publicfiles/css/receiver_registration.css @@ -0,0 +1,204 @@ +/* Colors */ +:root { + --color-background: #f9f8f9; + --color-text-primary: #000000; + --color-text-secondary: #6f6e77; + --color-text-tertiary: #908e96; + + --color-input-border: #eeedef; + --color-input-background: #ffffff; + --color-input-text: #1a1523; + + --color-button-primary-background: #1c1c1f; + --color-button-primary-border: #28282c; + --color-button-primary-text: #ffffff; + --color-button-secondary-background: #f4f2f4; + --color-button-secondary-border: #eeedef; + --color-button-secondary-text: #000000; + + --color-info-error-background: #ffefef; + --color-info-error-border: #ffe5e5; + --color-info-error-icon: #eb9091; + + --color-info-success-background: #e9f9ee; + --color-info-success-border: #ddf3e4; + --color-info-success-icon: #5bb98c; +} + +body { + background-color: var(--color-background); + color: var(--color-text-secondary); + font-family: "Inter", sans-serif; + font-size: 16px; + margin: 0; + padding: 0; + font-weight: 400; +} + +body * { + box-sizing: border-box; +} + +.WalletRegistration { + font-size: 0.875rem; + padding: 1.5rem; + max-width: 50rem; + margin: 0 auto; + position: relative; + min-height: 100vh; + display: flex; + flex-direction: column; +} + +.WalletRegistration__MainContent { + flex: 1; +} + +.WalletRegistration__Footer { + flex: 1; + display: flex; + flex-direction: column; + justify-content: flex-end; +} + +section { + flex: 1; + display: flex; + flex-direction: column; + justify-content: space-between; + gap: 1rem; +} + +h2 { + font-family: "Inter Tight"; + font-weight: 400; + font-size: 1.25rem; + line-height: 1.75rem; + color: var(--color-text-primary); + margin: 0; + padding: 0; + margin-bottom: 1.25rem; +} + +p { + color: var(--color-text-secondary); + font-size: 0.875rem; + line-height: 1.375rem; +} + +form { + display: flex; + flex-direction: column; + gap: 1rem; + margin-top: 1.5rem; +} + +label { + font-weight: 500; + font-size: 0.875rem; + line-height: 1.375rem; + color: var(--color-text-tertiary); +} + +input { + font-size: 1rem; + line-height: 1.5rem; + color: var(--color-input-text); + background-color: var(--color-input-background); + border: 1px solid var(--color-input-border); + border-radius: 0.25rem; + padding: 0.5rem 0.75rem; +} + +.Form__item { + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.Form__buttons { + display: flex; + align-items: center; + gap: 0.5rem; +} + +button { + font-family: "Inter"; + font-weight: 500; + font-size: 0.875rem; + line-height: 1.375rem; + border: 1px solid; + padding: 0.375rem 0.625rem; + border-radius: 0.25rem; + cursor: pointer; + transition: opacity linear 500ms; +} + +button:disabled { + opacity: 0.8; + cursor: not-allowed; +} + +button.Button--primary { + color: var(--color-button-primary-text); + background-color: var(--color-button-primary-background); + border-color: var(--color-button-primary-border); +} + +button.Button--secondary { + color: var(--color-button-secondary-text); + background-color: var(--color-button-secondary-background); + border-color: var(--color-button-secondary-border); +} + +.Notification { + display: flex; + flex-direction: column; + gap: 0.5rem; + padding: 1rem; + border: 1px solid; + border-radius: 0.25rem; +} + +.Notification__title { + display: flex; + align-items: center; + gap: 0.3rem; + color: var(--color-text-primary); + font-size: 0.875rem; + line-height: 1.375rem; + font-weight: 500; +} + +.Notification__title svg { + display: block; + flex-shrink: 0; + width: 1rem; + height: 1rem; + align-self: flex-start; + margin-top: 0.2rem; +} + +.Notification__content { + color: var(--color-text-secondary); + font-size: 0.875rem; + line-height: 1.375rem; +} + +.Notification--error { + background-color: var(--color-info-error-background); + border-color: var(--color-info-error-border); +} + +.Notification--error svg { + fill: var(--color-info-error-icon); +} + +.Notification--success { + background-color: var(--color-info-success-background); + border-color: var(--color-info-success-border); +} + +.Notification--success svg { + fill: var(--color-info-success-icon); +} diff --git a/internal/serve/publicfiles/css/text_mock.css b/internal/serve/publicfiles/css/text_mock.css new file mode 100644 index 000000000..1f69d13f6 --- /dev/null +++ b/internal/serve/publicfiles/css/text_mock.css @@ -0,0 +1 @@ +.text-blue { color: #3300ff; } diff --git a/internal/serve/publicfiles/img/logo.png b/internal/serve/publicfiles/img/logo.png new file mode 100644 index 000000000..6e3d47efe Binary files /dev/null and b/internal/serve/publicfiles/img/logo.png differ diff --git a/internal/serve/publicfiles/js/receiver_registered_successfully.js b/internal/serve/publicfiles/js/receiver_registered_successfully.js new file mode 100644 index 000000000..c76eb48fc --- /dev/null +++ b/internal/serve/publicfiles/js/receiver_registered_successfully.js @@ -0,0 +1,18 @@ +// Purpose: to let other windows know that the receiver has been registered successfully +postMessage("verified"); + +document.addEventListener("DOMContentLoaded", function() { + const button = document.getElementById('backToHomeButton'); + + button.addEventListener('click', function(event) { + backToHome(event); + }); +}); + + +function backToHome(event) { + window.close(); + + // Purpose: to let other windows know that this window has been closed + postMessage('close'); +} \ No newline at end of file diff --git a/internal/serve/publicfiles/js/receiver_registration.js b/internal/serve/publicfiles/js/receiver_registration.js new file mode 100644 index 000000000..d1409e9a4 --- /dev/null +++ b/internal/serve/publicfiles/js/receiver_registration.js @@ -0,0 +1,270 @@ +const WalletRegistration = { + jwtToken: "", +}; + +function getJwtToken() { + const tokenEl = document.querySelector("[data-jwt-token]"); + + if (tokenEl) { + return tokenEl.innerHTML; + } +} + +function toggleNotification(type, { parentEl, title, message, isVisible }) { + const titleEl = parentEl.querySelector(`[data-section-${type}-title]`); + const messageEl = parentEl.querySelector(`[data-section-${type}-message`); + + if (titleEl && messageEl) { + if (isVisible) { + parentEl.style.display = "flex"; + titleEl.innerHTML = title; + messageEl.innerHTML = message; + } else { + parentEl.style.display = "none"; + titleEl.innerHTML = ""; + messageEl.innerHTML = ""; + } + } +} + +function toggleErrorNotification(parentEl, title, message, isVisible) { + toggleNotification("error", { parentEl, title, message, isVisible }); +} + +function toggleSuccessNotification(parentEl, title, message, isVisible) { + toggleNotification("success", { parentEl, title, message, isVisible }); +} + +async function sendSms(phoneNumber, reCAPTCHAToken, onSuccess, onError) { + if (phoneNumber && reCAPTCHAToken) { + try { + const request = await fetch("/wallet-registration/otp", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${WalletRegistration.jwtToken}`, + }, + body: JSON.stringify({ + phone_number: phoneNumber, + recaptcha_token: reCAPTCHAToken, + }), + }); + await request.json(); + + onSuccess(); + } catch (error) { + onError(error); + } + } +} + +function disableButtons(buttons) { + buttons.forEach((b) => { + b.disabled = true; + }); +} + +function enableButtons(buttons) { + const t = window.setTimeout(() => { + buttons.forEach((b) => { + b.disabled = false; + }); + + clearTimeout(t); + }, 1000); +} + +document.addEventListener("DOMContentLoaded", function() { + const form = document.getElementById("submitPhoneNumberForm"); + + form.addEventListener("submit", function(event) { + submitPhoneNumber(event); + }); +}); + +async function submitPhoneNumber(event) { + event.preventDefault(); + const phoneNumberSectionEl = document.querySelector( + "[data-section='phoneNumber']", + ); + const passcodeSectionEl = document.querySelector("[data-section='passcode']"); + const errorNotificationEl = document.querySelector( + "[data-section-error='phoneNumber']", + ); + const phoneNumberEl = document.getElementById("phone_number"); + const reCAPTCHATokenEl = phoneNumberSectionEl.querySelector("#g-recaptcha-response") + const buttonEls = phoneNumberSectionEl.querySelectorAll("[data-button]"); + + if (!reCAPTCHATokenEl || !reCAPTCHATokenEl.value) { + toggleErrorNotification(errorNotificationEl, "Error", "reCAPTCHA is required", true); + return; + } + + toggleErrorNotification(errorNotificationEl, "", "", false); + + if ( + phoneNumberEl && + reCAPTCHATokenEl && + phoneNumberSectionEl && + passcodeSectionEl && + errorNotificationEl + ) { + disableButtons(buttonEls); + const phoneNumber = phoneNumberEl.value; + const reCAPTCHAToken = reCAPTCHATokenEl.value; + + function showNextPage() { + phoneNumberSectionEl.style.display = "none"; + reCAPTCHATokenEl.style.display = "none"; + passcodeSectionEl.style.display = "flex"; + enableButtons(buttonEls); + } + + function showErrorMessage(error) { + toggleErrorNotification(errorNotificationEl, "Error", error, true); + enableButtons(buttonEls); + } + + sendSms(phoneNumber, reCAPTCHAToken, showNextPage, showErrorMessage); + } +} + +document.addEventListener("DOMContentLoaded", function() { + const form = document.getElementById("submitOtpForm"); + + form.addEventListener("submit", function(event) { + submitOtp(event); + }); +}); + +async function submitOtp(event) { + event.preventDefault(); + + const passcodeSectionEl = document.querySelector("[data-section='passcode']"); + const errorNotificationEl = document.querySelector( + "[data-section-error='passcode']", + ); + const successNotificationEl = document.querySelector( + "[data-section-success='passcode']", + ); + const phoneNumberEl = document.getElementById("phone_number"); + const otpEl = document.getElementById("otp"); + const verificationEl = document.getElementById("verification"); + + const buttonEls = passcodeSectionEl.querySelectorAll("[data-button]"); + + const reCAPTCHATokenEl = passcodeSectionEl.querySelector("#g-recaptcha-response-1"); + if (!reCAPTCHATokenEl || !reCAPTCHATokenEl.value) { + toggleErrorNotification(errorNotificationEl, "Error", "reCAPTCHA is required", true); + return; + } + + if ( + phoneNumberEl && + otpEl && + verificationEl && + passcodeSectionEl && + errorNotificationEl + ) { + toggleErrorNotification(errorNotificationEl, "", "", false); + toggleSuccessNotification(successNotificationEl, "", "", false); + + const phoneNumber = phoneNumberEl.value; + const otp = otpEl.value; + const verification = verificationEl.value; + + if (phoneNumber && otp && verification) { + try { + disableButtons(buttonEls); + + const response = await fetch("/wallet-registration/verification", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${WalletRegistration.jwtToken}`, + }, + body: JSON.stringify({ + phone_number: phoneNumber, + otp: otp, + verification: verification, + verification_type: "date_of_birth", + recaptcha_token: reCAPTCHATokenEl.value, + }), + }); + + if ([200, 201].includes(response.status)) { + await response.json(); + + const t = window.setTimeout(() => { + location.reload(); + clearTimeout(t); + }, 2000); + } else { + throw new Error("Something went wrong, please try again later."); + } + } catch (error) { + enableButtons(buttonEls); + toggleErrorNotification(errorNotificationEl, "Error", error, true); + grecaptcha.reset(1); + } + } + } +} + +document.addEventListener("DOMContentLoaded", function() { + const button = document.getElementById('resendSmsButton'); + + button.addEventListener('click', function(event) { + resendSms(event); + }); +}); + +async function resendSms() { + const passcodeSectionEl = document.querySelector("[data-section='passcode']"); + const errorNotificationEl = document.querySelector( + "[data-section-error='passcode']", + ); + const successNotificationEl = document.querySelector( + "[data-section-success='passcode']", + ); + const buttonEls = passcodeSectionEl.querySelectorAll("[data-button]"); + const phoneNumberEl = document.getElementById("phone_number"); + const reCAPTCHATokenEl = passcodeSectionEl.querySelector("#g-recaptcha-response-1"); + + if (!reCAPTCHATokenEl || !reCAPTCHATokenEl.value) { + toggleErrorNotification(errorNotificationEl, "Error", "reCAPTCHA is required", true); + return; + } + + if ((passcodeSectionEl, errorNotificationEl, phoneNumberEl, reCAPTCHATokenEl)) { + disableButtons(buttonEls); + toggleErrorNotification(errorNotificationEl, "", "", false); + toggleSuccessNotification(successNotificationEl, "", "", false); + + const phoneNumber = phoneNumberEl.value; + const reCAPTCHAToken = reCAPTCHATokenEl.value; + + function showErrorMessage(error) { + toggleErrorNotification(errorNotificationEl, "Error", error, true); + enableButtons(buttonEls); + } + + function showSuccessMessage() { + toggleSuccessNotification( + successNotificationEl, + "New SMS sent", + "You will receive a new one-time passcode", + true, + ); + enableButtons(buttonEls); + } + + sendSms(phoneNumber, reCAPTCHAToken, showSuccessMessage, showErrorMessage); + grecaptcha.reset(1); + } +} + +// Init +window.onload = async () => { + WalletRegistration.jwtToken = getJwtToken(); +}; diff --git a/internal/serve/publicfiles/js/test_mock.js b/internal/serve/publicfiles/js/test_mock.js new file mode 100644 index 000000000..39565705d --- /dev/null +++ b/internal/serve/publicfiles/js/test_mock.js @@ -0,0 +1 @@ +console.log("test mock file."); diff --git a/internal/serve/publicfiles/main.go b/internal/serve/publicfiles/main.go new file mode 100644 index 000000000..5835aa750 --- /dev/null +++ b/internal/serve/publicfiles/main.go @@ -0,0 +1,6 @@ +package publicfiles + +import "embed" + +//go:embed css/* js/* img/* +var PublicFiles embed.FS diff --git a/internal/serve/serve.go b/internal/serve/serve.go new file mode 100644 index 000000000..878eb8b58 --- /dev/null +++ b/internal/serve/serve.go @@ -0,0 +1,437 @@ +package serve + +import ( + "fmt" + "io/fs" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + chimiddleware "github.com/go-chi/chi/v5/middleware" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/network" + supporthttp "github.com/stellar/go/support/http" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/anchorplatform" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httperror" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httphandler" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + publicfiles "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/publicfiles" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/validators" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + txnsubmitterutils "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +const ServiceID = "serve" + +type HTTPServerInterface interface { + Run(conf supporthttp.Config) +} + +type HTTPServer struct{} + +func (h *HTTPServer) Run(conf supporthttp.Config) { + supporthttp.Run(conf) +} + +type ServeOptions struct { + Environment string + GitCommit string + Port int + Version string + MonitorService monitor.MonitorServiceInterface + DatabaseDSN string + dbConnectionPool db.DBConnectionPool + EC256PublicKey string + EC256PrivateKey string + Models *data.Models + CorsAllowedOrigins []string + authManager auth.AuthManager + EmailMessengerClient message.MessengerClient + SMSMessengerClient message.MessengerClient + SEP24JWTSecret string + sep24JWTManager *anchorplatform.JWTManager + BaseURL string + UIBaseURL string + ResetTokenExpirationHours int + NetworkPassphrase string + HorizonURL string + horizonClient horizonclient.ClientInterface + signatureService engine.SignatureService + Sep10SigningPublicKey string + Sep10SigningPrivateKey string + AnchorPlatformBasePlatformURL string + AnchorPlatformBaseSepURL string + AnchorPlatformOutgoingJWTSecret string + anchorPlatformAPIService anchorplatform.AnchorPlatformAPIServiceInterface + CrashTrackerClient crashtracker.CrashTrackerClient + DistributionPublicKey string + DistributionSeed string + ReCAPTCHASiteKey string + ReCAPTCHASiteSecretKey string + EnableMFA bool + EnableReCAPTCHA bool +} + +// SetupDependencies uses the serve options to setup the dependencies for the server. +func (opts *ServeOptions) SetupDependencies() error { + // Setup crash tracker: + // Call crash tracker FlushEvents to flush buffered events before the server terminates + defer opts.CrashTrackerClient.FlushEvents(2 * time.Second) + // Call crash tracker Recover for recover from unhandled panics + defer opts.CrashTrackerClient.Recover() + // Set crash tracker LogAndReportErrors as DefaultReportErrorFunc + httperror.SetDefaultReportErrorFunc(opts.CrashTrackerClient.LogAndReportErrors) + + // Setup Database: + dbConnectionPool, err := db.OpenDBConnectionPoolWithMetrics(opts.DatabaseDSN, opts.MonitorService) + if err != nil { + return fmt.Errorf("error connecting to the database: %w", err) + } + opts.Models, err = data.NewModels(dbConnectionPool) + if err != nil { + return fmt.Errorf("error creating models for Serve: %w", err) + } + opts.dbConnectionPool = dbConnectionPool + + // Setup Stellar Auth JWT manager + opts.authManager, err = createAuthManager( + opts.dbConnectionPool, opts.EC256PublicKey, opts.EC256PrivateKey, opts.ResetTokenExpirationHours, + ) + if err != nil { + return fmt.Errorf("error creating Stellar Auth manager: %w", err) + } + + // Setup Anchor Platform SEP24 JWT manager + sep24JWTManager, err := anchorplatform.NewJWTManager(opts.SEP24JWTSecret, 15000) + if err != nil { + return fmt.Errorf("error creating SEP-24 JWT manager: %w", err) + } + opts.sep24JWTManager = sep24JWTManager + + // Setup Anchor Platform API Service + opts.anchorPlatformAPIService, err = anchorplatform.NewAnchorPlatformAPIService(httpclient.DefaultClient(), opts.AnchorPlatformBasePlatformURL, opts.AnchorPlatformOutgoingJWTSecret) + if err != nil { + return fmt.Errorf("error creating Anchor Platform API service: %w", err) + } + + // Setup Horizon Client + opts.horizonClient = &horizonclient.Client{ + HorizonURL: opts.HorizonURL, + HTTP: httpclient.DefaultClient(), + } + + // Setup Signature Service + // TODO: improve the way we setup signature service + opts.signatureService, err = engine.NewDefaultSignatureService( + opts.NetworkPassphrase, + dbConnectionPool, + opts.DistributionSeed, + store.NewChannelAccountModel(opts.dbConnectionPool), + txnsubmitterutils.DefaultPrivateKeyEncrypter{}, + opts.DistributionSeed, + ) + if err != nil { + return fmt.Errorf("error creating signature service: %w", err) + } + + return nil +} + +func Serve(opts ServeOptions, httpServer HTTPServerInterface) error { + err := opts.SetupDependencies() + if err != nil { + return fmt.Errorf("error starting dependencies: %w", err) + } + + // Start the server + listenAddr := fmt.Sprintf(":%d", opts.Port) + serverConfig := supporthttp.Config{ + ListenAddr: listenAddr, + Handler: handleHTTP(opts), + TCPKeepAlive: time.Minute * 3, + ShutdownGracePeriod: time.Second * 50, + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 35, + IdleTimeout: time.Minute * 2, + OnStarting: func() { + log.Info("Starting SDP (Stellar Disbursement Platform) Server") + log.Infof("Listening on %s", listenAddr) + }, + OnStopping: func() { + log.Info("Closing the database connection...") + err := opts.dbConnectionPool.Close() + if err != nil { + log.Errorf("error closing database connection: %s", err.Error()) + } + + log.Info("Stopping SDP (Stellar Disbursement Platform) Server") + }, + } + httpServer.Run(serverConfig) + return nil +} + +func handleHTTP(o ServeOptions) *chi.Mux { + mux := chi.NewMux() + + // Middleware + mux.Use(middleware.CorsMiddleware(o.CorsAllowedOrigins)) + mux.Use(chimiddleware.RequestID) + mux.Use(chimiddleware.RealIP) + mux.Use(supporthttp.LoggingMiddleware) + mux.Use(middleware.RecoverHandler) + mux.Use(middleware.MetricsRequestHandler(o.MonitorService)) + mux.Use(middleware.CSPMiddleware()) + + // Create a route along /static that will serve contents from the ./public_files folder. + staticFileServer(mux, publicfiles.PublicFiles) + + // Authenticated Routes + authManager := o.authManager + mux.Group(func(r chi.Router) { + r.Use(middleware.AuthenticateMiddleware(authManager)) + + r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)).Route("/statistics", func(r chi.Router) { + statisticsHandler := httphandler.StatisticsHandler{DBConnectionPool: o.dbConnectionPool} + r.Get("/", statisticsHandler.GetStatistics) + r.Get("/{id}", statisticsHandler.GetStatisticsByDisbursement) + }) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole)).Route("/users", func(r chi.Router) { + userHandler := httphandler.UserHandler{ + AuthManager: authManager, + MessengerClient: o.EmailMessengerClient, + UIBaseURL: o.UIBaseURL, + Models: o.Models, + } + + r.Get("/", userHandler.GetAllUsers) + r.Post("/", userHandler.CreateUser) + r.Get("/roles", httphandler.ListRolesHandler{}.GetRoles) + r.Patch("/roles", userHandler.UpdateUserRoles) + r.Patch("/activation", userHandler.UserActivation) + }) + r.Post("/refresh-token", httphandler.RefreshTokenHandler{AuthManager: authManager}.PostRefreshToken) + + r.Route("/disbursements", func(r chi.Router) { + handler := httphandler.DisbursementHandler{ + Models: o.Models, + MonitorService: o.MonitorService, + DBConnectionPool: o.dbConnectionPool, + AuthManager: authManager, + } + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Post("/", handler.PostDisbursement) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Post("/{id}/instructions", handler.PostDisbursementInstructions) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Get("/{id}/instructions", handler.GetDisbursementInstructions) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole, data.BusinessUserRole)). + Get("/", handler.GetDisbursements) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole, data.BusinessUserRole)). + Get("/{id}", handler.GetDisbursement) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole, data.BusinessUserRole)). + Get("/{id}/receivers", handler.GetDisbursementReceivers) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Patch("/{id}/status", handler.PatchDisbursementStatus) + }) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole, data.BusinessUserRole)).Route("/payments", func(r chi.Router) { + paymentsHandler := httphandler.PaymentsHandler{Models: o.Models, DBConnectionPool: o.dbConnectionPool, AuthManager: o.authManager} + r.Get("/", paymentsHandler.GetPayments) + r.Get("/{id}", paymentsHandler.GetPayment) + r.Patch("/retry", paymentsHandler.RetryPayments) + }) + + r.Route("/receivers", func(r chi.Router) { + receiversHandler := httphandler.ReceiverHandler{Models: o.Models, DBConnectionPool: o.dbConnectionPool} + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole, data.BusinessUserRole)). + Get("/", receiversHandler.GetReceivers) + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Get("/{id}", receiversHandler.GetReceiver) + + updateReceiverHandler := httphandler.UpdateReceiverHandler{Models: o.Models, DBConnectionPool: o.dbConnectionPool} + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Patch("/{id}", updateReceiverHandler.UpdateReceiver) + }) + + r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)).Route("/countries", func(r chi.Router) { + r.Get("/", httphandler.CountriesHandler{Models: o.Models}.GetCountries) + }) + + r.Route("/assets", func(r chi.Router) { + assetsHandler := httphandler.AssetsHandler{ + Models: o.Models, + SignatureService: o.signatureService, + HorizonClient: o.horizonClient, + } + + r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)). + Get("/", assetsHandler.GetAssets) + + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole, data.DeveloperUserRole)). + Post("/", assetsHandler.CreateAsset) + + r.Route("/{id}", func(r chi.Router) { + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole, data.DeveloperUserRole)).Delete("/", assetsHandler.DeleteAsset) + }) + }) + + r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)).Route("/wallets", func(r chi.Router) { + r.Get("/", httphandler.WalletsHandler{Models: o.Models}.GetWallets) + }) + + profileHandler := httphandler.ProfileHandler{ + Models: o.Models, + AuthManager: authManager, + MaxMemoryAllocation: httphandler.DefaultMaxMemoryAllocation, + BaseURL: o.BaseURL, + DistributionPublicKey: o.DistributionPublicKey, + } + r.Route("/profile", func(r chi.Router) { + r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)). + Get("/", profileHandler.GetProfile) + + r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)). + Patch("/", profileHandler.PatchUserProfile) + }) + + r.Route("/organization", func(r chi.Router) { + r.With(middleware.AnyRoleMiddleware(authManager, data.OwnerUserRole, data.FinancialControllerUserRole)). + Patch("/", profileHandler.PatchOrganizationProfile) + + r.With(middleware.AnyRoleMiddleware(authManager, data.GetAllRoles()...)). + Get("/", profileHandler.GetOrganizationInfo) + }) + }) + + // Even if the logo URL is under the public endpoints, it'll be authenticated. The `auth token` should be + // added in the URL's query params. Example: https://...?token=mytoken + mux.Get("/organization/logo", httphandler.ProfileHandler{Models: o.Models, PublicFilesFS: publicfiles.PublicFiles}.GetOrganizationLogo) + + mux.Get("/health", httphandler.HealthHandler{ + ReleaseID: o.GitCommit, + ServiceID: ServiceID, + Version: o.Version, + }.ServeHTTP) + + reCAPTCHAValidator := validators.NewGoogleReCAPTCHAValidator(o.ReCAPTCHASiteSecretKey, httpclient.DefaultClient()) + + mux.Post("/login", httphandler.LoginHandler{ + AuthManager: authManager, + ReCAPTCHAValidator: reCAPTCHAValidator, + MessengerClient: o.EmailMessengerClient, + Models: o.Models, + ReCAPTCHAEnabled: o.EnableReCAPTCHA, + MFAEnabled: o.EnableMFA, + }.ServeHTTP) + mux.Post("/mfa", httphandler.MFAHandler{ + AuthManager: authManager, + ReCAPTCHAValidator: reCAPTCHAValidator, + Models: o.Models, + }.ServeHTTP) + mux.Post("/forgot-password", httphandler.ForgotPasswordHandler{ + AuthManager: authManager, + MessengerClient: o.EmailMessengerClient, + UIBaseURL: o.UIBaseURL, + Models: o.Models, + ReCAPTCHAValidator: reCAPTCHAValidator, + ReCAPTCHAEnabled: o.EnableReCAPTCHA, + }.ServeHTTP) + mux.Post("/reset-password", httphandler.ResetPasswordHandler{AuthManager: authManager}.ServeHTTP) + + // START SEP-24 endpoints + mux.Get("/.well-known/stellar.toml", httphandler.StellarTomlHandler{ + AnchorPlatformBaseSepURL: o.AnchorPlatformBaseSepURL, + DistributionPublicKey: o.DistributionPublicKey, + NetworkPassphrase: o.NetworkPassphrase, + Models: o.Models, + Sep10SigningPublicKey: o.Sep10SigningPublicKey, + }.ServeHTTP) + + mux.Route("/wallet-registration", func(r chi.Router) { + sep24QueryTokenAuthenticationMiddleware := anchorplatform.SEP24QueryTokenAuthenticateMiddleware(o.sep24JWTManager, o.NetworkPassphrase) + r.With(sep24QueryTokenAuthenticationMiddleware).Get("/start", httphandler.ReceiverRegistrationHandler{ReceiverWalletModel: o.Models.ReceiverWallet, ReCAPTCHASiteKey: o.ReCAPTCHASiteKey}.ServeHTTP) // This loads the SEP-24 PII registration webpage. + + sep24HeaderTokenAuthenticationMiddleware := anchorplatform.SEP24HeaderTokenAuthenticateMiddleware(o.sep24JWTManager, o.NetworkPassphrase) + r.With(sep24HeaderTokenAuthenticationMiddleware).Post("/otp", httphandler.ReceiverSendOTPHandler{Models: o.Models, SMSMessengerClient: o.SMSMessengerClient, ReCAPTCHAValidator: reCAPTCHAValidator}.ServeHTTP) + r.With(sep24HeaderTokenAuthenticationMiddleware).Post("/verification", httphandler.VerifyReceiverRegistrationHandler{ + AnchorPlatformAPIService: o.anchorPlatformAPIService, + Models: o.Models, + ReCAPTCHAValidator: reCAPTCHAValidator, + NetworkPassphrase: o.NetworkPassphrase, + }.VerifyReceiverRegistration) + + // This will be used for test purposes and will only be available when IsPubnet is false: + if o.NetworkPassphrase == network.TestNetworkPassphrase { + r.Delete("/phone-number/{phone_number}", httphandler.DeletePhoneNumberHandler{Models: o.Models, NetworkPassphrase: o.NetworkPassphrase}.ServeHTTP) + } + }) + // END SEP-24 endpoints + + return mux +} + +// createAuthManager builds the default AuthManager struct to be injected +// in all the authentication related routes. +func createAuthManager(dbConnectionPool db.DBConnectionPool, ec256PublicKey, ec256PrivateKey string, resetTokenExpirationHours int) (auth.AuthManager, error) { + if dbConnectionPool == nil { + return nil, fmt.Errorf("db connection pool cannot be nil") + } + + err := utils.ValidateECDSAKeys(ec256PublicKey, ec256PrivateKey) + if err != nil { + return nil, fmt.Errorf("validating auth manager keys: %w", err) + } + + if resetTokenExpirationHours < 1 { + return nil, fmt.Errorf("reset token expiration hours must be greater than 0") + } + + passwordEncrypter := auth.NewDefaultPasswordEncrypter() + + authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName()) + authManager := auth.NewAuthManager( + auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(resetTokenExpirationHours)), + auth.WithDefaultJWTManagerOption(ec256PublicKey, ec256PrivateKey), + auth.WithDefaultRoleManagerOption(authDBConnectionPool, data.OwnerUserRole.String()), + auth.WithDefaultMFAManagerOption(authDBConnectionPool), + ) + + return authManager, nil +} + +// staticFileServer sets up a http.FileServer handler to serve +// static files from publicFiles embed FileSystem. +func staticFileServer(r chi.Router, fileSystem fs.FS) { + r.Get("/static/*", func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") + + // Don't allow users to list directories + if r.URL.Path[len(r.URL.Path)-1] == '/' { + http.NotFound(w, r) + return + } + + fs := http.StripPrefix(pathPrefix, http.FileServer(http.FS(fileSystem))) + fs.ServeHTTP(w, r) + }) +} diff --git a/internal/serve/serve_metrics.go b/internal/serve/serve_metrics.go new file mode 100644 index 000000000..584a3111c --- /dev/null +++ b/internal/serve/serve_metrics.go @@ -0,0 +1,52 @@ +package serve + +import ( + "fmt" + "time" + + "github.com/go-chi/chi/v5" + supporthttp "github.com/stellar/go/support/http" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" +) + +type MetricsServeOptions struct { + Port int + Environment string + + MonitorService monitor.MonitorServiceInterface + MetricType monitor.MetricType +} + +func MetricsServe(opts MetricsServeOptions, httpServer HTTPServerInterface) error { + metricsAddr := fmt.Sprintf(":%d", opts.Port) + metricsServerConfig := supporthttp.Config{ + ListenAddr: metricsAddr, + Handler: handleMetricsHttp(opts), + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 2 * time.Minute, + OnStarting: func() { + log.Infof("Starting %s Metrics Server", opts.MetricType) + log.Infof("Listening on %s", metricsAddr) + }, + OnStopping: func() { + log.Infof("Stopping %s Metrics Server", opts.MetricType) + }, + } + + httpServer.Run(metricsServerConfig) + return nil +} + +func handleMetricsHttp(opts MetricsServeOptions) *chi.Mux { + mux := chi.NewMux() + + metricHttpHandler, err := opts.MonitorService.GetMetricHttpHandler() + if err != nil { + log.Fatalf("Error getting metric http.handler: %s", err.Error()) + } + + mux.Handle("/metrics", metricHttpHandler) + return mux +} diff --git a/internal/serve/serve_metrics_test.go b/internal/serve/serve_metrics_test.go new file mode 100644 index 000000000..f34570c6b --- /dev/null +++ b/internal/serve/serve_metrics_test.go @@ -0,0 +1,47 @@ +package serve + +import ( + "net/http" + "testing" + "time" + + supporthttp "github.com/stellar/go/support/http" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_ServeMetrics(t *testing.T) { + mMonitorService := &monitor.MockMonitorService{} + + mMonitorService.On("GetMetricHttpHandler"). + Return(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }), nil).Twice() + + opts := MetricsServeOptions{ + Port: 8002, + MetricType: "MOCKMETRICTYPE", + MonitorService: mMonitorService, + } + + // Mock supportHTTPRun + mHTTPServer := mockHTTPServer{} + mHTTPServer.On("Run", mock.AnythingOfType("http.Config")).Run(func(args mock.Arguments) { + conf, ok := args.Get(0).(supporthttp.Config) + require.True(t, ok, "should be of type supporthttp.Config") + assert.Equal(t, ":8002", conf.ListenAddr) + assert.Equal(t, time.Second*5, conf.ReadTimeout) + assert.Equal(t, time.Second*10, conf.WriteTimeout) + assert.Equal(t, time.Minute*2, conf.IdleTimeout) + assert.Nil(t, conf.TLS) + assert.ObjectsAreEqualValues(handleMetricsHttp(opts), conf.Handler) + }).Once() + + // test and assert + err := MetricsServe(opts, &mHTTPServer) + require.NoError(t, err) + mHTTPServer.AssertExpectations(t) + mMonitorService.AssertExpectations(t) +} diff --git a/internal/serve/serve_test.go b/internal/serve/serve_test.go new file mode 100644 index 000000000..3c9155c05 --- /dev/null +++ b/internal/serve/serve_test.go @@ -0,0 +1,386 @@ +package serve + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + supporthttp "github.com/stellar/go/support/http" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + publicfiles "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/publicfiles" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockHTTPServer struct { + mock.Mock +} + +func (m *mockHTTPServer) Run(conf supporthttp.Config) { + m.Called(conf) +} + +const ( + publicKeyStr = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAER88h7AiQyVDysRTxKvBB6CaiO/kS +cvGyimApUE/12gFhNTRf37SE19CSCllKxstnVFOpLLWB7Qu5OJ0Wvcz3hg== +-----END PUBLIC KEY-----` + privateKeyStr = `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgIqI1MzMZIw2pQDLx +Jn0+FcNT/hNjwtn2TW43710JKZqhRANCAARHzyHsCJDJUPKxFPEq8EHoJqI7+RJy +8bKKYClQT/XaAWE1NF/ftITX0JIKWUrGy2dUU6kstYHtC7k4nRa9zPeG +-----END PRIVATE KEY-----` +) + +func Test_Serve(t *testing.T) { + dbt := dbtest.OpenWithoutMigrations(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + mockCrashTrackerClient := &crashtracker.MockCrashTrackerClient{} + + opts := ServeOptions{ + CrashTrackerClient: mockCrashTrackerClient, + DatabaseDSN: dbt.DSN, + EC256PrivateKey: privateKeyStr, + EC256PublicKey: publicKeyStr, + Environment: "test", + GitCommit: "1234567890abcdef", + Models: models, + Port: 8000, + ResetTokenExpirationHours: 1, + SEP24JWTSecret: "jwt_secret_1234567890", + Version: "x.y.z", + NetworkPassphrase: network.TestNetworkPassphrase, + DistributionSeed: keypair.MustRandom().Seed(), + } + + // Mock supportHTTPRun + mHTTPServer := mockHTTPServer{} + mHTTPServer.On("Run", mock.AnythingOfType("http.Config")).Run(func(args mock.Arguments) { + conf, ok := args.Get(0).(supporthttp.Config) + require.True(t, ok, "should be of type supporthttp.Config") + assert.Equal(t, ":8000", conf.ListenAddr) + assert.Equal(t, time.Minute*3, conf.TCPKeepAlive) + assert.Equal(t, time.Second*50, conf.ShutdownGracePeriod) + assert.Equal(t, time.Second*5, conf.ReadTimeout) + assert.Equal(t, time.Second*35, conf.WriteTimeout) + assert.Equal(t, time.Minute*2, conf.IdleTimeout) + assert.Nil(t, conf.TLS) + assert.ObjectsAreEqualValues(handleHTTP(opts), conf.Handler) + conf.OnStopping() + }).Once() + mockCrashTrackerClient.On("FlushEvents", 2*time.Second).Return(false).Once() + mockCrashTrackerClient.On("Recover").Once() + + // test and assert + err = Serve(opts, &mHTTPServer) + require.NoError(t, err) + mHTTPServer.AssertExpectations(t) + mockCrashTrackerClient.AssertExpectations(t) +} + +func Test_handleHTTP_Health(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + mMonitorService := &monitor.MockMonitorService{} + mLabels := monitor.HttpRequestLabels{ + Status: "200", + Route: "/health", + Method: "GET", + } + mMonitorService.On("MonitorHttpRequestDuration", mock.AnythingOfType("time.Duration"), mLabels).Return(nil).Once() + + handlerMux := handleHTTP(ServeOptions{ + EC256PrivateKey: privateKeyStr, + EC256PublicKey: publicKeyStr, + Environment: "test", + GitCommit: "1234567890abcdef", + Models: models, + MonitorService: mMonitorService, + SEP24JWTSecret: "jwt_secret_1234567890", + Version: "x.y.z", + }) + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + handlerMux.ServeHTTP(w, req) + + resp := w.Result() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + wantBody := `{ + "status": "pass", + "version": "x.y.z", + "service_id": "serve", + "release_id": "1234567890abcdef" + }` + assert.JSONEq(t, wantBody, string(body)) + mMonitorService.AssertExpectations(t) +} + +func Test_staticFileServer(t *testing.T) { + r := chi.NewMux() + + staticFileServer(r, publicfiles.PublicFiles) + + t.Run("Should return not found when tryig to access a folder", func(t *testing.T) { + req, err := http.NewRequest("GET", "/static/", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusNotFound, rr.Code) + assert.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type")) + assert.Equal(t, "404 page not found\n", string(data)) + }) + + t.Run("Should return file contents on a valid file", func(t *testing.T) { + req, err := http.NewRequest("GET", "/static/js/test_mock.js", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + resp := rr.Result() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Header().Get("Content-Type"), "javascript") + assert.Equal(t, "console.log(\"test mock file.\");\n", string(data)) + }) +} + +// getServeOptionsForTests returns an instance of ServeOptions for testing purposes. +// 🚨 Don't forget to call `defer serveOptions.dbConnectionPool.Close()` in your test 🚨. +func getServeOptionsForTests(t *testing.T, databaseDSN string) ServeOptions { + t.Helper() + + mMonitorService := &monitor.MockMonitorService{} + mMonitorService.On("MonitorHttpRequestDuration", mock.AnythingOfType("time.Duration"), mock.Anything).Return(nil) + + messengerClientMock := message.MessengerClientMock{} + messengerClientMock.On("SendMessage", mock.Anything).Return(nil) + + crasTrackerClient, err := crashtracker.NewDryRunClient() + require.NoError(t, err) + + serveOptions := ServeOptions{ + CrashTrackerClient: crasTrackerClient, + DatabaseDSN: databaseDSN, + EC256PrivateKey: privateKeyStr, + EC256PublicKey: publicKeyStr, + EmailMessengerClient: &messengerClientMock, + Environment: "test", + GitCommit: "1234567890abcdef", + MonitorService: mMonitorService, + ResetTokenExpirationHours: 1, + SEP24JWTSecret: "jwt_secret_1234567890", + SMSMessengerClient: &messengerClientMock, + Version: "x.y.z", + NetworkPassphrase: network.TestNetworkPassphrase, + DistributionSeed: keypair.MustRandom().Seed(), + } + err = serveOptions.SetupDependencies() + require.NoError(t, err) + + return serveOptions +} + +func Test_handleHTTP_unauthenticatedEndpoints(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + serveOptions := getServeOptionsForTests(t, dbt.DSN) + defer serveOptions.dbConnectionPool.Close() + + handlerMux := handleHTTP(serveOptions) + + // Unauthenticated endpoints + unauthenticatedEndpoints := []struct { // TODO: body to requests + method string + path string + }{ + {http.MethodGet, "/health"}, + {http.MethodPost, "/login"}, + {http.MethodPost, "/forgot-password"}, + {http.MethodPost, "/reset-password"}, + } + for _, endpoint := range unauthenticatedEndpoints { + t.Run(fmt.Sprintf("%s %s", endpoint.method, endpoint.path), func(t *testing.T) { + req := httptest.NewRequest(endpoint.method, endpoint.path, nil) + w := httptest.NewRecorder() + handlerMux.ServeHTTP(w, req) + + resp := w.Result() + assert.Contains(t, []int{http.StatusOK, http.StatusBadRequest}, resp.StatusCode) + }) + } +} + +func Test_handleHTTP_authenticatedEndpoints(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + serveOptions := getServeOptionsForTests(t, dbt.DSN) + defer serveOptions.dbConnectionPool.Close() + + handlerMux := handleHTTP(serveOptions) + + // Unauthenticated endpoints + authenticatedEndpoints := []struct { // TODO: body to requests + method string + path string + }{ + // Statistics + {http.MethodGet, "/statistics"}, + {http.MethodGet, "/statistics/1234"}, + // Users + {http.MethodGet, "/users"}, + {http.MethodPost, "/users"}, + {http.MethodGet, "/users/roles"}, + {http.MethodPatch, "/users/roles"}, + {http.MethodPatch, "/users/activation"}, + // Refresh Token + {http.MethodPost, "/refresh-token"}, + // Disbursements + {http.MethodPost, "/disbursements"}, + {http.MethodPost, "/disbursements/1234/instructions"}, + {http.MethodGet, "/disbursements/1234/instructions"}, + {http.MethodGet, "/disbursements"}, + {http.MethodGet, "/disbursements/1234"}, + {http.MethodGet, "/disbursements/1234/receivers"}, + {http.MethodGet, "/disbursements/1234/status"}, + // Payments + {http.MethodGet, "/payments"}, + {http.MethodGet, "/payments/1234"}, + {http.MethodPatch, "/payments/retry"}, + // Receivers + {http.MethodGet, "/receivers"}, + {http.MethodGet, "/receivers/1234"}, + {http.MethodPatch, "/receivers/1234"}, + // Countries + {http.MethodGet, "/countries"}, + // Assets + {http.MethodGet, "/assets"}, + {http.MethodPost, "/assets"}, + {http.MethodPatch, "/assets/1234"}, + {http.MethodDelete, "/assets/1234"}, + // Profile + {http.MethodGet, "/profile"}, + {http.MethodPatch, "/profile"}, + // Organization + {http.MethodGet, "/organization"}, + {http.MethodPatch, "/organization"}, + } + + // Expect 401 as a response: + for _, endpoint := range authenticatedEndpoints { + t.Run(fmt.Sprintf("expect 401 for %s %s", endpoint.method, endpoint.path), func(t *testing.T) { + req := httptest.NewRequest(endpoint.method, endpoint.path, nil) + w := httptest.NewRecorder() + handlerMux.ServeHTTP(w, req) + + resp := w.Result() + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) + } +} + +func Test_createAuthManager(t *testing.T) { + dbt := dbtest.OpenWithoutMigrations(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + // creates the expected auth manager + passwordEncrypter := auth.NewDefaultPasswordEncrypter() + authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName()) + wantAuthManager := auth.NewAuthManager( + auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(1)), + auth.WithDefaultJWTManagerOption(publicKeyStr, privateKeyStr), + auth.WithDefaultRoleManagerOption(authDBConnectionPool, data.OwnerUserRole.String()), + auth.WithDefaultMFAManagerOption(authDBConnectionPool), + ) + + testCases := []struct { + name string + dbConnectionPool db.DBConnectionPool + ec256PublicKey string + ec256PrivateKey string + resetTokenExpirationHours int + wantErrContains string + wantAuthManager auth.AuthManager + }{ + { + name: "returns error if dbConnectionPool is nil", + wantErrContains: "db connection pool cannot be nil", + }, + { + name: "returns error if dbConnectionPool is valid but the keypair is not", + dbConnectionPool: dbConnectionPool, + wantErrContains: "validating auth manager keys: validating ECDSA public key: failed to decode PEM block containing public key", + }, + { + name: "returns error if dbConnectionPool and keypair is valid but the resetTokenExpirationHours is not", + dbConnectionPool: dbConnectionPool, + ec256PublicKey: publicKeyStr, + ec256PrivateKey: privateKeyStr, + wantErrContains: "reset token expiration hours must be greater than 0", + }, + { + name: "πŸŽ‰ successfully create the auth manager", + dbConnectionPool: dbConnectionPool, + ec256PublicKey: publicKeyStr, + ec256PrivateKey: privateKeyStr, + resetTokenExpirationHours: 1, + wantAuthManager: wantAuthManager, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotAuthManager, err := createAuthManager( + tc.dbConnectionPool, tc.ec256PublicKey, tc.ec256PrivateKey, tc.resetTokenExpirationHours, + ) + if tc.wantErrContains != "" { + assert.Contains(t, tc.wantErrContains, err.Error()) + assert.Empty(t, gotAuthManager) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.wantAuthManager, gotAuthManager) + } + }) + } +} diff --git a/internal/serve/validators/disbursement_instructions_validator.go b/internal/serve/validators/disbursement_instructions_validator.go new file mode 100644 index 000000000..71ee96656 --- /dev/null +++ b/internal/serve/validators/disbursement_instructions_validator.go @@ -0,0 +1,47 @@ +package validators + +import ( + "fmt" + "strings" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +type DisbursementInstructionsValidator struct { + verificationField data.VerificationField + *Validator +} + +func NewDisbursementInstructionsValidator(verificationField data.VerificationField) *DisbursementInstructionsValidator { + return &DisbursementInstructionsValidator{ + verificationField: verificationField, + Validator: NewValidator(), + } +} + +func (iv *DisbursementInstructionsValidator) ValidateInstruction(instruction *data.DisbursementInstruction, lineNumber int) { + phone := strings.TrimSpace(instruction.Phone) + id := strings.TrimSpace(instruction.ID) + amount := strings.TrimSpace(instruction.Amount) + verification := strings.TrimSpace(instruction.VerificationValue) + + // validate phone field + iv.CheckError(utils.ValidatePhoneNumber(phone), fmt.Sprintf("line %d - phone", lineNumber), "invalid phone format. Correct format: +380445555555") + iv.Check(strings.TrimSpace(phone) != "", fmt.Sprintf("line %d - phone", lineNumber), "phone cannot be empty") + + // validate id field + iv.Check(strings.TrimSpace(id) != "", fmt.Sprintf("line %d - id", lineNumber), "id cannot be empty") + + // validate amount field + iv.CheckError(utils.ValidateAmount(amount), fmt.Sprintf("line %d - amount", lineNumber), "invalid amount. Amount must be a positive number") + + // validate verification field + // date of birth with format 2006-01-02 + if iv.verificationField == data.VerificationFieldDateOfBirth { + _, err := time.Parse("2006-01-02", verification) + iv.CheckError(err, fmt.Sprintf("line %d - birthday", lineNumber), "invalid date of birth format. Correct format: 1990-01-01") + } +} diff --git a/internal/serve/validators/disbursement_instructions_validator_test.go b/internal/serve/validators/disbursement_instructions_validator_test.go new file mode 100644 index 000000000..95e53cd79 --- /dev/null +++ b/internal/serve/validators/disbursement_instructions_validator_test.go @@ -0,0 +1,124 @@ +package validators + +import ( + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" +) + +func Test_DisbursementInstructionsValidator_ValidateAndGetInstruction(t *testing.T) { + tests := []struct { + name string + actual *data.DisbursementInstruction + lineNumber int + hasErrors bool + expectedErrors map[string]interface{} + }{ + { + name: "valid record", + actual: &data.DisbursementInstruction{ + Phone: "+380445555555", + ID: "123456789", + Amount: "100.5", + VerificationValue: "1990-01-01", + }, + lineNumber: 1, + hasErrors: false, + }, + { + name: "empty phone number", + actual: &data.DisbursementInstruction{ + ID: "123456789", + Amount: "100.5", + VerificationValue: "1990-01-01", + }, + lineNumber: 2, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "line 2 - phone": "phone cannot be empty", + }, + }, + { + name: "empty phone, id, amount and birthday", + actual: &data.DisbursementInstruction{}, + lineNumber: 2, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "line 2 - amount": "invalid amount. Amount must be a positive number", + "line 2 - birthday": "invalid date of birth format. Correct format: 1990-01-01", + "line 2 - id": "id cannot be empty", + "line 2 - phone": "phone cannot be empty", + }, + }, + { + name: "invalid phone number", + actual: &data.DisbursementInstruction{ + Phone: "+123-12-345-678", + ID: "123456789", + Amount: "100.5", + VerificationValue: "1990-01-01", + }, + lineNumber: 2, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "line 2 - phone": "invalid phone format. Correct format: +380445555555", + }, + }, + { + name: "invalid amount format", + actual: &data.DisbursementInstruction{ + Phone: "+380445555555", + ID: "123456789", + Amount: "100.5USDC", + VerificationValue: "1990-01-01", + }, + lineNumber: 3, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "line 3 - amount": "invalid amount. Amount must be a positive number", + }, + }, + { + name: "amount must be positive", + actual: &data.DisbursementInstruction{ + Phone: "+380445555555", + ID: "123456789", + Amount: "-100.5", + VerificationValue: "1990-01-01", + }, + lineNumber: 3, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "line 3 - amount": "invalid amount. Amount must be a positive number", + }, + }, + { + name: "invalid birthday format", + actual: &data.DisbursementInstruction{ + Phone: "+380445555555", + ID: "123456789", + Amount: "100.5", + VerificationValue: "1990/01/01", + }, + lineNumber: 3, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "line 3 - birthday": "invalid date of birth format. Correct format: 1990-01-01", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iv := NewDisbursementInstructionsValidator(data.VerificationFieldDateOfBirth) + iv.ValidateInstruction(tt.actual, tt.lineNumber) + + if tt.hasErrors { + assert.Equal(t, tt.expectedErrors, iv.Errors) + } else { + assert.Empty(t, iv.Errors) + } + }) + } +} diff --git a/internal/serve/validators/disbursement_query_validator.go b/internal/serve/validators/disbursement_query_validator.go new file mode 100644 index 000000000..a50c166a6 --- /dev/null +++ b/internal/serve/validators/disbursement_query_validator.go @@ -0,0 +1,69 @@ +package validators + +import ( + "strings" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +type DisbursementQueryValidator struct { + QueryValidator +} + +// NewDisbursementQueryValidator creates a new DisbursementQueryValidator with the provided configuration. +func NewDisbursementQueryValidator() *DisbursementQueryValidator { + return &DisbursementQueryValidator{ + QueryValidator: QueryValidator{ + DefaultSortField: data.DefaultDisbursementSortField, + DefaultSortOrder: data.DefaultDisbursementSortOrder, + AllowedSortFields: data.AllowedDisbursementSorts, + AllowedFilters: data.AllowedDisbursementFilters, + Validator: NewValidator(), + }, + } +} + +// ValidateAndGetDisbursementFilters validates the filters and returns a map of valid filters. +func (qv *DisbursementQueryValidator) ValidateAndGetDisbursementFilters(filters map[data.FilterKey]interface{}) map[data.FilterKey]interface{} { + validFilters := make(map[data.FilterKey]interface{}) + if filters[data.FilterKeyStatus] != nil { + validFilters[data.FilterKeyStatus] = qv.validateAndGetDisbursementStatuses(filters[data.FilterKeyStatus].(string)) + } + + createdAtAfter := qv.ValidateAndGetTimeParams(string(data.FilterKeyCreatedAtAfter), filters[data.FilterKeyCreatedAtAfter]) + createdAtBefore := qv.ValidateAndGetTimeParams(string(data.FilterKeyCreatedAtBefore), filters[data.FilterKeyCreatedAtBefore]) + + if qv.HasErrors() { + return validFilters + } + + if !createdAtAfter.IsZero() && !createdAtBefore.IsZero() { + qv.Check(createdAtAfter.Before(createdAtBefore), string(data.FilterKeyCreatedAtAfter), "created_at_after must be before created_at_before") + } + + if !createdAtAfter.IsZero() { + validFilters[data.FilterKeyCreatedAtAfter] = createdAtAfter + } + if !createdAtBefore.IsZero() { + validFilters[data.FilterKeyCreatedAtBefore] = createdAtBefore + } + return validFilters +} + +// validateAndGetDisbursementStatuses takes a comma-separated string of disbursement statuses +// and returns a slice of valid DisbursementStatus values. +func (qv *DisbursementQueryValidator) validateAndGetDisbursementStatuses(statuses string) []data.DisbursementStatus { + statusList := strings.Split(statuses, ",") + validStatuses := []data.DisbursementStatus{} + + for _, status := range statusList { + s := data.DisbursementStatus(strings.ToUpper(strings.TrimSpace(status))) + switch s { + case data.DraftDisbursementStatus, data.ReadyDisbursementStatus, data.StartedDisbursementStatus, data.PausedDisbursementStatus, data.CompletedDisbursementStatus: + validStatuses = append(validStatuses, s) + default: + qv.Check(false, string(data.FilterKeyStatus), "invalid parameter. valid value is a comma separate list of statuses: draft, ready, started, paused, completed") + } + } + return validStatuses +} diff --git a/internal/serve/validators/disbursement_query_validator_test.go b/internal/serve/validators/disbursement_query_validator_test.go new file mode 100644 index 000000000..ef9ef43e2 --- /dev/null +++ b/internal/serve/validators/disbursement_query_validator_test.go @@ -0,0 +1,118 @@ +package validators + +import ( + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" +) + +func Test_DisbursementQueryValidator_ValidateDisbursementFilters(t *testing.T) { + t.Run("Valid filters", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "2023-01-01", + data.FilterKeyCreatedAtBefore: "2023-01-31", + } + + actual := validator.ValidateAndGetDisbursementFilters(filters) + + assert.Equal(t, []data.DisbursementStatus{data.DraftDisbursementStatus}, actual[data.FilterKeyStatus]) + assert.Equal(t, time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), actual[data.FilterKeyCreatedAtAfter]) + assert.Equal(t, time.Date(2023, 1, 31, 0, 0, 0, 0, time.UTC), actual[data.FilterKeyCreatedAtBefore]) + }) + + t.Run("Invalid status", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "unknown", + } + + validator.ValidateAndGetDisbursementFilters(filters) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid value is a comma separate list of statuses: draft, ready, started, paused, completed", validator.Errors["status"]) + }) + + t.Run("Invalid date", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "00-01-31", + data.FilterKeyCreatedAtBefore: "00-01-01", + } + + validator.ValidateAndGetDisbursementFilters(filters) + + assert.Equal(t, 2, len(validator.Errors)) + assert.Equal(t, "invalid date format. valid format is 'YYYY-MM-DD'", validator.Errors["created_at_after"]) + assert.Equal(t, "invalid date format. valid format is 'YYYY-MM-DD'", validator.Errors["created_at_before"]) + }) + + t.Run("Invalid date range", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "2023-01-31", + data.FilterKeyCreatedAtBefore: "2023-01-01", + } + + validator.ValidateAndGetDisbursementFilters(filters) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "created_at_after must be before created_at_before", validator.Errors["created_at_after"]) + }) +} + +func Test_DisbursementQueryValidator_ValidateAndGetDisbursementStatuses(t *testing.T) { + t.Run("Valid status", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + validStatus := []data.DisbursementStatus{data.DraftDisbursementStatus, data.ReadyDisbursementStatus, data.StartedDisbursementStatus, data.PausedDisbursementStatus, data.CompletedDisbursementStatus} + for _, status := range validStatus { + assert.Equal(t, []data.DisbursementStatus{status}, validator.validateAndGetDisbursementStatuses(string(status))) + } + }) + + t.Run("Invalid status", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + invalidStatus := "unknown" + + actual := validator.validateAndGetDisbursementStatuses(invalidStatus) + assert.Empty(t, actual) + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid value is a comma separate list of statuses: draft, ready, started, paused, completed", validator.Errors["status"]) + }) + + t.Run("mix of valid and invalid statuses", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + statuses := "unknown1,unknown2,draft" + + actual := validator.validateAndGetDisbursementStatuses(statuses) + assert.Equal(t, 1, len(actual)) + assert.Equal(t, []data.DisbursementStatus{data.DraftDisbursementStatus}, actual) + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid value is a comma separate list of statuses: draft, ready, started, paused, completed", validator.Errors["status"]) + }) + + t.Run("valid comma separated list of statuses", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + statuses := "draft,ready,completed" + + actual := validator.validateAndGetDisbursementStatuses(statuses) + assert.Equal(t, 3, len(actual)) + assert.Equal(t, []data.DisbursementStatus{data.DraftDisbursementStatus, data.ReadyDisbursementStatus, data.CompletedDisbursementStatus}, actual) + assert.Equal(t, 0, len(validator.Errors)) + }) + + t.Run("valid comma separated list of statuses with spaces", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + statuses := " draft , ready , completed " + + actual := validator.validateAndGetDisbursementStatuses(statuses) + assert.Equal(t, 3, len(actual)) + assert.Equal(t, []data.DisbursementStatus{data.DraftDisbursementStatus, data.ReadyDisbursementStatus, data.CompletedDisbursementStatus}, actual) + assert.Equal(t, 0, len(validator.Errors)) + }) +} diff --git a/internal/serve/validators/mock.go b/internal/serve/validators/mock.go new file mode 100644 index 000000000..f3958deaf --- /dev/null +++ b/internal/serve/validators/mock.go @@ -0,0 +1,25 @@ +package validators + +import ( + "context" + "net/http" + + "github.com/stretchr/testify/mock" +) + +type ReCAPTCHAValidatorMock struct { + mock.Mock +} + +func (v *ReCAPTCHAValidatorMock) IsTokenValid(ctx context.Context, token string) (bool, error) { + args := v.Called(ctx, token) + return args.Bool(0), args.Error(1) +} + +type httpClientMock struct { + mockDo func(req *http.Request) (*http.Response, error) +} + +func (c *httpClientMock) Do(req *http.Request) (*http.Response, error) { + return c.mockDo(req) +} diff --git a/internal/serve/validators/payment_query_validator.go b/internal/serve/validators/payment_query_validator.go new file mode 100644 index 000000000..85252da2c --- /dev/null +++ b/internal/serve/validators/payment_query_validator.go @@ -0,0 +1,66 @@ +package validators + +import ( + "strings" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +type PaymentQueryValidator struct { + QueryValidator +} + +// NewPaymentQueryValidator creates a new PaymentQueryValidator with the provided configuration. +func NewPaymentQueryValidator() *PaymentQueryValidator { + return &PaymentQueryValidator{ + QueryValidator: QueryValidator{ + DefaultSortField: data.DefaultPaymentSortField, + DefaultSortOrder: data.DefaultPaymentSortOrder, + AllowedSortFields: data.AllowedPaymentSorts, + AllowedFilters: data.AllowedPaymentFilters, + Validator: NewValidator(), + }, + } +} + +// ValidateAndGetPaymentFilters validates the filters and returns a map of valid filters. +func (qv *PaymentQueryValidator) ValidateAndGetPaymentFilters(filters map[data.FilterKey]interface{}) map[data.FilterKey]interface{} { + validFilters := make(map[data.FilterKey]interface{}) + if filters[data.FilterKeyStatus] != nil { + validFilters[data.FilterKeyStatus] = qv.validateAndGetPaymentStatus(filters[data.FilterKeyStatus].(string)) + } + if filters[data.FilterKeyReceiverID] != nil { + validFilters[data.FilterKeyReceiverID] = filters[data.FilterKeyReceiverID] + } + + createdAtAfter := qv.ValidateAndGetTimeParams(string(data.FilterKeyCreatedAtAfter), filters[data.FilterKeyCreatedAtAfter]) + createdAtBefore := qv.ValidateAndGetTimeParams(string(data.FilterKeyCreatedAtBefore), filters[data.FilterKeyCreatedAtBefore]) + + if qv.HasErrors() { + return validFilters + } + + if !createdAtAfter.IsZero() && !createdAtBefore.IsZero() { + qv.Check(createdAtAfter.Before(createdAtBefore), string(data.FilterKeyCreatedAtAfter), "created_at_after must be before created_at_before") + } + + if !createdAtAfter.IsZero() { + validFilters[data.FilterKeyCreatedAtAfter] = createdAtAfter + } + if !createdAtBefore.IsZero() { + validFilters[data.FilterKeyCreatedAtBefore] = createdAtBefore + } + return validFilters +} + +// validateAndGetPaymentStatus validates the status parameter and returns the corresponding PaymentStatus. +func (qv *PaymentQueryValidator) validateAndGetPaymentStatus(status string) data.PaymentStatus { + s := data.PaymentStatus(strings.ToUpper(status)) + switch s { + case data.DraftPaymentStatus, data.ReadyPaymentStatus, data.PendingPaymentStatus, data.PausedPaymentStatus, data.SuccessPaymentStatus, data.FailedPaymentStatus: + return s + default: + qv.Check(false, string(data.FilterKeyStatus), "invalid parameter. valid values are: draft, ready, pending, paused, success, failed") + return "" + } +} diff --git a/internal/serve/validators/payment_query_validator_test.go b/internal/serve/validators/payment_query_validator_test.go new file mode 100644 index 000000000..e8c3b5da0 --- /dev/null +++ b/internal/serve/validators/payment_query_validator_test.go @@ -0,0 +1,89 @@ +package validators + +import ( + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" +) + +func Test_PaymentQueryValidator_ValidateDisbursementFilters(t *testing.T) { + t.Run("Valid filters", func(t *testing.T) { + validator := NewPaymentQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyReceiverID: "receiver_id", + data.FilterKeyCreatedAtAfter: "2023-01-01", + data.FilterKeyCreatedAtBefore: "2023-01-31", + } + + actual := validator.ValidateAndGetPaymentFilters(filters) + + assert.Equal(t, data.DraftPaymentStatus, actual[data.FilterKeyStatus]) + assert.Equal(t, "receiver_id", actual[data.FilterKeyReceiverID]) + assert.Equal(t, time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), actual[data.FilterKeyCreatedAtAfter]) + assert.Equal(t, time.Date(2023, 1, 31, 0, 0, 0, 0, time.UTC), actual[data.FilterKeyCreatedAtBefore]) + }) + + t.Run("Invalid status", func(t *testing.T) { + validator := NewPaymentQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "unknown", + } + + validator.ValidateAndGetPaymentFilters(filters) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid values are: draft, ready, pending, paused, success, failed", validator.Errors["status"]) + }) + + t.Run("Invalid date", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "00-01-31", + data.FilterKeyCreatedAtBefore: "00-01-01", + } + + validator.ValidateAndGetDisbursementFilters(filters) + + assert.Equal(t, 2, len(validator.Errors)) + assert.Equal(t, "invalid date format. valid format is 'YYYY-MM-DD'", validator.Errors["created_at_after"]) + assert.Equal(t, "invalid date format. valid format is 'YYYY-MM-DD'", validator.Errors["created_at_before"]) + }) + + t.Run("Invalid date range", func(t *testing.T) { + validator := NewDisbursementQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "2023-01-31", + data.FilterKeyCreatedAtBefore: "2023-01-01", + } + + validator.ValidateAndGetDisbursementFilters(filters) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "created_at_after must be before created_at_before", validator.Errors["created_at_after"]) + }) +} + +func Test_PaymentQueryValidator_ValidateAndGetPaymentStatus(t *testing.T) { + t.Run("Valid status", func(t *testing.T) { + validator := NewPaymentQueryValidator() + validStatus := []data.PaymentStatus{data.DraftPaymentStatus, data.ReadyPaymentStatus, data.PendingPaymentStatus, data.PausedPaymentStatus, data.SuccessPaymentStatus, data.FailedPaymentStatus} + for _, status := range validStatus { + assert.Equal(t, status, validator.validateAndGetPaymentStatus(string(status))) + } + }) + + t.Run("Invalid status", func(t *testing.T) { + validator := NewPaymentQueryValidator() + invalidStatus := "unknown" + + actual := validator.validateAndGetPaymentStatus(invalidStatus) + assert.Empty(t, actual) + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid values are: draft, ready, pending, paused, success, failed", validator.Errors["status"]) + }) +} diff --git a/internal/serve/validators/query_validator.go b/internal/serve/validators/query_validator.go new file mode 100644 index 000000000..89b229201 --- /dev/null +++ b/internal/serve/validators/query_validator.go @@ -0,0 +1,98 @@ +package validators + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "golang.org/x/exp/slices" +) + +type QueryValidator struct { + *Validator + DefaultSortField data.SortField + DefaultSortOrder data.SortOrder + AllowedSortFields []data.SortField + AllowedFilters []data.FilterKey +} + +// ParseParametersFromRequest parses query parameters from the request and returns a QueryParams struct. +func (qv *QueryValidator) ParseParametersFromRequest(r *http.Request) *data.QueryParams { + page := qv.validateAndGetIntParams(r, "page", 1) + pageLimit := qv.validateAndGetIntParams(r, "page_limit", 20) + + query := r.URL.Query() + sortBy := data.SortField(query.Get("sort")) + if sortBy == "" { + sortBy = qv.DefaultSortField + } else if !slices.Contains(qv.AllowedSortFields, sortBy) { + qv.addError("sort", "invalid sort field name") + } + + sortOrder := data.SortOrder(strings.ToUpper(query.Get("direction"))) + if sortOrder == "" { + sortOrder = qv.DefaultSortOrder + } else if sortOrder != data.SortOrderASC && sortOrder != data.SortOrderDESC { + qv.addError("direction", "invalid sort order. valid values are 'asc' and 'desc'") + } + + filters := make(map[data.FilterKey]interface{}) + for _, fk := range qv.AllowedFilters { + value := strings.TrimSpace(query.Get(string(fk))) + if value != "" { + filters[fk] = value + } + } + + if qv.HasErrors() { + return &data.QueryParams{} + } + + return &data.QueryParams{ + Query: strings.TrimSpace(query.Get("q")), + Page: page, + PageLimit: pageLimit, + SortBy: sortBy, + SortOrder: sortOrder, + Filters: filters, + } +} + +// validateAndGetIntParams validates the query parameter and returns the value as an integer. +func (qv *QueryValidator) validateAndGetIntParams(r *http.Request, param string, defaultValue int) int { + value := r.URL.Query().Get(param) + if value == "" { + return defaultValue + } + + intValue, err := strconv.Atoi(value) + if err != nil { + qv.CheckError(err, param, "parameter must be an integer") + return defaultValue + } + + return intValue +} + +// ValidateAndGetTimeParams validates the query parameter and returns the value as a time.Time. +func (qv *QueryValidator) ValidateAndGetTimeParams(param string, value interface{}) time.Time { + if value == nil { + return time.Time{} + } + + dateStr, ok := value.(string) + if !ok { + qv.Check(false, param, "invalid date format. valid format is 'YYYY-MM-DD'") + return time.Time{} + } + + dateParam, err := time.Parse("2006-01-02", dateStr) + if err != nil { + qv.Check(false, param, "invalid date format. valid format is 'YYYY-MM-DD'") + return time.Time{} + } + + return dateParam +} diff --git a/internal/serve/validators/query_validator_test.go b/internal/serve/validators/query_validator_test.go new file mode 100644 index 000000000..fa7699412 --- /dev/null +++ b/internal/serve/validators/query_validator_test.go @@ -0,0 +1,176 @@ +package validators + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" +) + +func Test_QueryValidator_ParseQueryParameters(t *testing.T) { + tests := []struct { + name string + url string + defaultSortBy data.SortField + defaultSortOrder data.SortOrder + allowedSortFields []data.SortField + filterKeys []data.FilterKey + expectedParams *data.QueryParams + hasErrors bool + expectedErrors map[string]interface{} + }{ + { + name: "no query parameters - return default values", + url: "http://example.com/test", + defaultSortBy: data.SortFieldName, + defaultSortOrder: data.SortOrderASC, + expectedParams: &data.QueryParams{ + Query: "", + Page: 1, + PageLimit: 20, + SortBy: data.SortFieldName, + SortOrder: data.SortOrderASC, + Filters: map[data.FilterKey]interface{}{}, + }, + hasErrors: false, + expectedErrors: map[string]interface{}{}, + }, + { + name: "valid query parameters", + url: "http://example.com/test?q=hello&page=2&page_limit=10&sort=created_at&direction=desc&status=completed&created_at_after=2020-01-01&created_at_before=2020-01-02", + defaultSortBy: data.SortFieldName, + defaultSortOrder: data.SortOrderASC, + allowedSortFields: []data.SortField{ + data.SortFieldName, + data.SortFieldCreatedAt, + }, + filterKeys: []data.FilterKey{ + data.FilterKeyStatus, + data.FilterKeyCreatedAtAfter, + data.FilterKeyCreatedAtBefore, + }, + expectedParams: &data.QueryParams{ + Query: "hello", + Page: 2, + PageLimit: 10, + SortBy: data.SortFieldCreatedAt, + SortOrder: data.SortOrderDESC, + Filters: map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "completed", + data.FilterKeyCreatedAtAfter: "2020-01-01", + data.FilterKeyCreatedAtBefore: "2020-01-02", + }, + }, + hasErrors: false, + expectedErrors: map[string]interface{}{}, + }, + { + name: "invalid page value", + url: "http://example.com/test?page=abc", + expectedParams: &data.QueryParams{}, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "page": "parameter must be an integer", + }, + }, + { + name: "invalid page_limit value", + url: "http://example.com/test?page_limit=abc", + expectedParams: &data.QueryParams{}, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "page_limit": "parameter must be an integer", + }, + }, + { + name: "invalid sort field", + url: "http://example.com/test?sort=abc", + expectedParams: &data.QueryParams{}, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "sort": "invalid sort field name", + }, + }, + { + name: "invalid sort order", + url: "http://example.com/test?direction=abc", + expectedParams: &data.QueryParams{}, + hasErrors: true, + expectedErrors: map[string]interface{}{ + "direction": "invalid sort order. valid values are 'asc' and 'desc'", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tt.url, nil) + v := &DisbursementQueryValidator{ + QueryValidator: QueryValidator{ + DefaultSortField: tt.defaultSortBy, + DefaultSortOrder: tt.defaultSortOrder, + AllowedSortFields: tt.allowedSortFields, + AllowedFilters: tt.filterKeys, + Validator: NewValidator(), + }, + } + params := v.ParseParametersFromRequest(req) + + assert.Equal(t, tt.expectedParams, params) + assert.Equal(t, tt.hasErrors, v.HasErrors()) + assert.Equal(t, tt.expectedErrors, v.Errors) + }) + } +} + +func Test_QueryValidator_ValidateAndGetIntParams(t *testing.T) { + tests := []struct { + name string + param string + url string + defaultValue int + expected int + hasError bool + }{ + { + name: "no parameter", + param: "limit", + url: "http://example.com/test", + defaultValue: 10, + expected: 10, + hasError: false, + }, + { + name: "valid parameter", + param: "limit", + url: "http://example.com/test?limit=5", + defaultValue: 10, + expected: 5, + hasError: false, + }, + { + name: "invalid parameter", + param: "limit", + url: "http://example.com/test?limit=abc", + defaultValue: 10, + expected: 10, + hasError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tt.url, nil) + qv := QueryValidator{ + Validator: NewValidator(), + } + + actual := qv.validateAndGetIntParams(req, tt.param, tt.defaultValue) + + assert.Equal(t, tt.expected, actual) + assert.Equal(t, tt.hasError, qv.HasErrors()) + }) + } +} diff --git a/internal/serve/validators/recaptcha.go b/internal/serve/validators/recaptcha.go new file mode 100644 index 000000000..298dd1970 --- /dev/null +++ b/internal/serve/validators/recaptcha.go @@ -0,0 +1,87 @@ +package validators + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +const ( + // timeoutOrDuplicateErrorCode is not a configuration error. + // Reference: https://developers.google.com/recaptcha/docs/verify#error_code_reference + timeoutOrDuplicateErrorCode = "timeout-or-duplicate" + + // verifyTokenURL is the URL used to verify if the token generated by captcha is valid. + verifyTokenURL = "https://www.google.com/recaptcha/api/siteverify" +) + +type ReCAPTCHAValidator interface { + IsTokenValid(ctx context.Context, token string) (bool, error) +} + +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type GoogleReCAPTCHAValidator struct { + SiteSecretKey string + VerifyTokenURL string + BaseURL string + HTTPClient HTTPClient +} + +type verifyTokenResponse struct { + Success bool `json:"success"` + ErrorCodes []string `json:"error-codes"` +} + +func (v *GoogleReCAPTCHAValidator) IsTokenValid(ctx context.Context, token string) (bool, error) { + payload := fmt.Sprintf("secret=%s&response=%s", v.SiteSecretKey, token) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.VerifyTokenURL, strings.NewReader(payload)) + if err != nil { + return false, fmt.Errorf("error creating request: %w", err) + } + + // The request doesn't work with application/json MIME type + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + resp, err := v.HTTPClient.Do(req) + if err != nil { + return false, fmt.Errorf("error requesting verify reCAPTCHA token: %w", err) + } + defer resp.Body.Close() + + respBodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return false, fmt.Errorf("error reading body response: %w", err) + } + + var respBody verifyTokenResponse + if err := json.Unmarshal(respBodyBytes, &respBody); err != nil { + return false, fmt.Errorf("error unmarshalling body response: %w", err) + } + + for _, errorCode := range respBody.ErrorCodes { + if errorCode == timeoutOrDuplicateErrorCode { + return false, nil + } + } + + if len(respBody.ErrorCodes) > 0 { + return false, fmt.Errorf("error returned by verify reCAPTCHA token: %v", respBody.ErrorCodes) + } + + return respBody.Success, nil +} + +func NewGoogleReCAPTCHAValidator(siteSecretKey string, httpClient HTTPClient) *GoogleReCAPTCHAValidator { + return &GoogleReCAPTCHAValidator{ + SiteSecretKey: siteSecretKey, + VerifyTokenURL: verifyTokenURL, + HTTPClient: httpClient, + } +} diff --git a/internal/serve/validators/recaptcha_test.go b/internal/serve/validators/recaptcha_test.go new file mode 100644 index 000000000..4e021e979 --- /dev/null +++ b/internal/serve/validators/recaptcha_test.go @@ -0,0 +1,172 @@ +package validators + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GoogleReCAPTCHAValidator(t *testing.T) { + siteSecretKey := "secretKey" + httpClientMock := &httpClientMock{} + + grv := NewGoogleReCAPTCHAValidator(siteSecretKey, httpClientMock) + + ctx := context.Background() + t.Run("returns error when requesting verify token URL fails", func(t *testing.T) { + token := "token" + + httpClientMock.mockDo = func(req *http.Request) (*http.Response, error) { + assert.Equal(t, verifyTokenURL, req.URL.String()) + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + + reqBody, err := io.ReadAll(req.Body) + require.NoError(t, err) + defer req.Body.Close() + + assert.Equal(t, fmt.Sprintf(`secret=%s&response=%s`, siteSecretKey, token), string(reqBody)) + + return &http.Response{ + Body: io.NopCloser(strings.NewReader("{}")), + StatusCode: http.StatusOK, + }, fmt.Errorf("unexpected error") + } + + isValid, err := grv.IsTokenValid(ctx, token) + + assert.False(t, isValid) + assert.EqualError(t, err, "error requesting verify reCAPTCHA token: unexpected error") + }) + + t.Run("returns error when an error code is returned", func(t *testing.T) { + token := "token" + + httpClientMock.mockDo = func(req *http.Request) (*http.Response, error) { + assert.Equal(t, verifyTokenURL, req.URL.String()) + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + + reqBody, err := io.ReadAll(req.Body) + require.NoError(t, err) + defer req.Body.Close() + + assert.Equal(t, fmt.Sprintf(`secret=%s&response=%s`, siteSecretKey, token), string(reqBody)) + + respBody := ` + { + "success": false, + "error-codes": [ + "bad-request" + ] + } + ` + + return &http.Response{ + Body: io.NopCloser(strings.NewReader(respBody)), + StatusCode: http.StatusOK, + }, nil + } + + isValid, err := grv.IsTokenValid(ctx, token) + + assert.False(t, isValid) + assert.EqualError(t, err, "error returned by verify reCAPTCHA token: [bad-request]") + }) + + t.Run("returns false when timeout-or-duplicate error code is returned", func(t *testing.T) { + token := "token" + + httpClientMock.mockDo = func(req *http.Request) (*http.Response, error) { + assert.Equal(t, verifyTokenURL, req.URL.String()) + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + + reqBody, err := io.ReadAll(req.Body) + require.NoError(t, err) + defer req.Body.Close() + + assert.Equal(t, fmt.Sprintf(`secret=%s&response=%s`, siteSecretKey, token), string(reqBody)) + + respBody := ` + { + "success": false, + "error-codes": [ + "bad-request", + "timeout-or-duplicate" + ] + } + ` + + return &http.Response{ + Body: io.NopCloser(strings.NewReader(respBody)), + StatusCode: http.StatusOK, + }, nil + } + + isValid, err := grv.IsTokenValid(ctx, token) + + assert.False(t, isValid) + assert.NoError(t, err) + }) + + t.Run("returns whether the token is invalid or not", func(t *testing.T) { + token := "token" + + // Token invalid + httpClientMock.mockDo = func(req *http.Request) (*http.Response, error) { + assert.Equal(t, verifyTokenURL, req.URL.String()) + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + + reqBody, err := io.ReadAll(req.Body) + require.NoError(t, err) + defer req.Body.Close() + + assert.Equal(t, fmt.Sprintf(`secret=%s&response=%s`, siteSecretKey, token), string(reqBody)) + + respBody := `{"success": false}` + + return &http.Response{ + Body: io.NopCloser(strings.NewReader(respBody)), + StatusCode: http.StatusOK, + }, nil + } + + isValid, err := grv.IsTokenValid(ctx, token) + + assert.False(t, isValid) + assert.NoError(t, err) + + // Token is valid + httpClientMock.mockDo = func(req *http.Request) (*http.Response, error) { + assert.Equal(t, verifyTokenURL, req.URL.String()) + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, "application/x-www-form-urlencoded", req.Header.Get("Content-Type")) + + reqBody, rErr := io.ReadAll(req.Body) + require.NoError(t, rErr) + defer req.Body.Close() + + assert.Equal(t, fmt.Sprintf(`secret=%s&response=%s`, siteSecretKey, token), string(reqBody)) + + respBody := `{"success": true}` + + return &http.Response{ + Body: io.NopCloser(strings.NewReader(respBody)), + StatusCode: http.StatusOK, + }, nil + } + + isValid, err = grv.IsTokenValid(ctx, token) + + assert.True(t, isValid) + assert.NoError(t, err) + }) +} diff --git a/internal/serve/validators/receiver_query_validator.go b/internal/serve/validators/receiver_query_validator.go new file mode 100644 index 000000000..021bc5817 --- /dev/null +++ b/internal/serve/validators/receiver_query_validator.go @@ -0,0 +1,63 @@ +package validators + +import ( + "strings" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +type ReceiverQueryValidator struct { + QueryValidator +} + +// NewReceiverQueryValidator creates a new ReceiverQueryValidator with the provided configuration. +func NewReceiverQueryValidator() *ReceiverQueryValidator { + return &ReceiverQueryValidator{ + QueryValidator: QueryValidator{ + DefaultSortField: data.DefaultReceiverSortField, + DefaultSortOrder: data.DefaultReceiverSortOrder, + AllowedSortFields: data.AllowedReceiverSorts, + AllowedFilters: data.AllowedReceiverFilters, + Validator: NewValidator(), + }, + } +} + +// ValidateAndGetReceiverFilters validates the filters and returns a map of valid filters. +func (qv *ReceiverQueryValidator) ValidateAndGetReceiverFilters(filters map[data.FilterKey]interface{}) map[data.FilterKey]interface{} { + validFilters := make(map[data.FilterKey]interface{}) + if filters[data.FilterKeyStatus] != nil { + validFilters[data.FilterKeyStatus] = qv.validateAndGetReceiverWalletStatus(filters[data.FilterKeyStatus].(string)) + } + + createdAtAfter := qv.ValidateAndGetTimeParams(string(data.FilterKeyCreatedAtAfter), filters[data.FilterKeyCreatedAtAfter]) + createdAtBefore := qv.ValidateAndGetTimeParams(string(data.FilterKeyCreatedAtBefore), filters[data.FilterKeyCreatedAtBefore]) + + if qv.HasErrors() { + return validFilters + } + + if !createdAtAfter.IsZero() && !createdAtBefore.IsZero() { + qv.Check(createdAtAfter.Before(createdAtBefore), string(data.FilterKeyCreatedAtAfter), "created_at_after must be before created_at_before") + } + + if !createdAtAfter.IsZero() { + validFilters[data.FilterKeyCreatedAtAfter] = createdAtAfter + } + if !createdAtBefore.IsZero() { + validFilters[data.FilterKeyCreatedAtBefore] = createdAtBefore + } + return validFilters +} + +// validateAndGetReceiverWalletStatus validates the status parameter and returns the corresponding ReceiverWalletStatus. +func (qv *ReceiverQueryValidator) validateAndGetReceiverWalletStatus(status string) data.ReceiversWalletStatus { + s := data.ReceiversWalletStatus(strings.ToUpper(status)) + switch s { + case data.DraftReceiversWalletStatus, data.ReadyReceiversWalletStatus, data.RegisteredReceiversWalletStatus, data.FlaggedReceiversWalletStatus: + return s + default: + qv.Check(false, string(data.FilterKeyStatus), "invalid parameter. valid values are: draft, ready, registered, flagged") + return "" + } +} diff --git a/internal/serve/validators/receiver_query_validator_test.go b/internal/serve/validators/receiver_query_validator_test.go new file mode 100644 index 000000000..07f09ef20 --- /dev/null +++ b/internal/serve/validators/receiver_query_validator_test.go @@ -0,0 +1,92 @@ +package validators + +import ( + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" +) + +func Test_ReceiverQueryValidator_ValidateReceiverFilters(t *testing.T) { + t.Run("Valid filters", func(t *testing.T) { + validator := NewReceiverQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "2023-01-01", + data.FilterKeyCreatedAtBefore: "2023-01-31", + } + + actual := validator.ValidateAndGetReceiverFilters(filters) + + assert.Equal(t, data.DraftReceiversWalletStatus, actual[data.FilterKeyStatus]) + assert.Equal(t, time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), actual[data.FilterKeyCreatedAtAfter]) + assert.Equal(t, time.Date(2023, 1, 31, 0, 0, 0, 0, time.UTC), actual[data.FilterKeyCreatedAtBefore]) + }) + + t.Run("Invalid status", func(t *testing.T) { + validator := NewReceiverQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "unknown", + } + + validator.ValidateAndGetReceiverFilters(filters) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid values are: draft, ready, registered, flagged", validator.Errors["status"]) + }) + + t.Run("Invalid date", func(t *testing.T) { + validator := NewReceiverQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "00-01-31", + data.FilterKeyCreatedAtBefore: "00-01-01", + } + + validator.ValidateAndGetReceiverFilters(filters) + + assert.Equal(t, 2, len(validator.Errors)) + assert.Equal(t, "invalid date format. valid format is 'YYYY-MM-DD'", validator.Errors["created_at_after"]) + assert.Equal(t, "invalid date format. valid format is 'YYYY-MM-DD'", validator.Errors["created_at_before"]) + }) + + t.Run("Invalid date range", func(t *testing.T) { + validator := NewReceiverQueryValidator() + filters := map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "draft", + data.FilterKeyCreatedAtAfter: "2023-01-31", + data.FilterKeyCreatedAtBefore: "2023-01-01", + } + + validator.ValidateAndGetReceiverFilters(filters) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "created_at_after must be before created_at_before", validator.Errors["created_at_after"]) + }) +} + +func Test_ReceiverQueryValidator_ValidateAndGetReceiverStatus(t *testing.T) { + t.Run("Valid status", func(t *testing.T) { + validator := NewReceiverQueryValidator() + validStatus := []data.ReceiversWalletStatus{ + data.DraftReceiversWalletStatus, + data.ReadyReceiversWalletStatus, + data.ReadyReceiversWalletStatus, + data.FlaggedReceiversWalletStatus, + } + for _, status := range validStatus { + assert.Equal(t, status, validator.validateAndGetReceiverWalletStatus(string(status))) + } + }) + + t.Run("Invalid status", func(t *testing.T) { + validator := NewReceiverQueryValidator() + invalidStatus := "unknown" + + actual := validator.validateAndGetReceiverWalletStatus(invalidStatus) + assert.Empty(t, actual) + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid values are: draft, ready, registered, flagged", validator.Errors["status"]) + }) +} diff --git a/internal/serve/validators/receiver_registration_validator.go b/internal/serve/validators/receiver_registration_validator.go new file mode 100644 index 000000000..b7b2b5e94 --- /dev/null +++ b/internal/serve/validators/receiver_registration_validator.go @@ -0,0 +1,69 @@ +package validators + +import ( + "strings" + "time" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" +) + +type ReceiverRegistrationValidator struct { + *Validator +} + +// NewReceiverRegistrationValidator creates a new ReceiverRegistrationValidator with the provided configuration. +func NewReceiverRegistrationValidator() *ReceiverRegistrationValidator { + return &ReceiverRegistrationValidator{ + Validator: NewValidator(), + } +} + +// ValidateReceiver validates if the infos present in the ReceiverRegistrationRequest are valids. +func (rv *ReceiverRegistrationValidator) ValidateReceiver(receiverInfo *data.ReceiverRegistrationRequest) { + phone := strings.TrimSpace(receiverInfo.PhoneNumber) + otp := strings.TrimSpace(receiverInfo.OTP) + verification := strings.TrimSpace(receiverInfo.VerificationValue) + verificationType := strings.TrimSpace(string(receiverInfo.VerificationType)) + + // validate phone field + rv.CheckError(utils.ValidatePhoneNumber(phone), "phone_number", "invalid phone format. Correct format: +380445555555") + rv.Check(strings.TrimSpace(phone) != "", "phone_number", "phone cannot be empty") + + // validate otp field + rv.CheckError(utils.ValidateOTP(otp), "otp", "invalid otp format. Needs to be a 6 digit value") + + // validate verification type field + rv.Check(verificationType != "", "verification_type", "verification type cannot be empty") + vt := rv.validateAndGetVerificationType(verificationType) + + // validate verification field + // date of birth with format 2006-01-02 + if vt == data.VerificationFieldDateOfBirth { + _, err := time.Parse("2006-01-02", verification) + rv.CheckError(err, "verification", "invalid date of birth format. Correct format: 1990-01-01") + } else { + // TODO: validate other VerificationField types. + log.Warnf("Verification type %v is not being validated for ValidateReceiver", vt) + } + + receiverInfo.PhoneNumber = phone + receiverInfo.OTP = otp + receiverInfo.VerificationValue = verification + receiverInfo.VerificationType = vt +} + +// validateAndGetVerificationType validates if the verification type field is a valid value. +func (rv *ReceiverRegistrationValidator) validateAndGetVerificationType(verificationType string) data.VerificationField { + vt := data.VerificationField(strings.ToUpper(verificationType)) + + switch vt { + case data.VerificationFieldDateOfBirth, data.VerificationFieldPin, data.VerificationFieldNationalID: + return vt + default: + rv.Check(false, "verification_type", "invalid parameter. valid values are: DATE_OF_BIRTH, PIN, NATIONAL_ID_NUMBER") + return "" + } +} diff --git a/internal/serve/validators/receiver_registration_validator_test.go b/internal/serve/validators/receiver_registration_validator_test.go new file mode 100644 index 000000000..2cb6257d3 --- /dev/null +++ b/internal/serve/validators/receiver_registration_validator_test.go @@ -0,0 +1,127 @@ +package validators + +import ( + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" +) + +func Test_ReceiverRegistrationValidator_ValidateReceiver(t *testing.T) { + t.Run("Invalid phone number", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + + receiverInfo := data.ReceiverRegistrationRequest{ + PhoneNumber: "invalid", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "DATE_OF_BIRTH", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid phone format. Correct format: +380445555555", validator.Errors["phone_number"]) + }) + + t.Run("Empty phone number", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + + receiverInfo := data.ReceiverRegistrationRequest{ + PhoneNumber: "", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "DATE_OF_BIRTH", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "phone cannot be empty", validator.Errors["phone_number"]) + }) + + t.Run("Invalid otp", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + + receiverInfo := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "12mock", + VerificationValue: "1990-01-01", + VerificationType: "DATE_OF_BIRTH", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid otp format. Needs to be a 6 digit value", validator.Errors["otp"]) + }) + + t.Run("Invalid verification type", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + + receiverInfo := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "1990-01-01", + VerificationType: "mock_type", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid values are: DATE_OF_BIRTH, PIN, NATIONAL_ID_NUMBER", validator.Errors["verification_type"]) + }) + + t.Run("Invalid date of birth", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + + receiverInfo := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555", + OTP: "123456", + VerificationValue: "90/01/01", + VerificationType: "DATE_OF_BIRTH", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid date of birth format. Correct format: 1990-01-01", validator.Errors["verification"]) + }) + + t.Run("Valid receiver values", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + + receiverInfo := data.ReceiverRegistrationRequest{ + PhoneNumber: "+380445555555 ", + OTP: " 123456 ", + VerificationValue: "1990-01-01 ", + VerificationType: "date_of_birth", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 0, len(validator.Errors)) + assert.Equal(t, "+380445555555", receiverInfo.PhoneNumber) + assert.Equal(t, "123456", receiverInfo.OTP) + assert.Equal(t, "1990-01-01", receiverInfo.VerificationValue) + assert.Equal(t, data.VerificationField("DATE_OF_BIRTH"), receiverInfo.VerificationType) + }) +} + +func Test_ReceiverRegistrationValidator_ValidateAndGetVerificationType(t *testing.T) { + t.Run("Valid verification type", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + validField := []data.VerificationField{ + data.VerificationFieldDateOfBirth, + data.VerificationFieldPin, + data.VerificationFieldNationalID, + } + for _, field := range validField { + assert.Equal(t, field, validator.validateAndGetVerificationType(string(field))) + } + }) + + t.Run("Invalid verification type", func(t *testing.T) { + validator := NewReceiverRegistrationValidator() + invalidStatus := "unknown" + + actual := validator.validateAndGetVerificationType(invalidStatus) + assert.Empty(t, actual) + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid parameter. valid values are: DATE_OF_BIRTH, PIN, NATIONAL_ID_NUMBER", validator.Errors["verification_type"]) + }) +} diff --git a/internal/serve/validators/receiver_update_validator.go b/internal/serve/validators/receiver_update_validator.go new file mode 100644 index 000000000..1213b1f6c --- /dev/null +++ b/internal/serve/validators/receiver_update_validator.go @@ -0,0 +1,70 @@ +package validators + +import ( + "strings" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type UpdateReceiverRequest struct { + DateOfBirth string `json:"date_of_birth"` + Pin string `json:"pin"` + NationalID string `json:"national_id"` + Email string `json:"email"` + ExternalID string `json:"external_id"` +} +type UpdateReceiverValidator struct { + *Validator +} + +// NewReceiverRegistrationValidator creates a new ReceiverRegistrationValidator with the provided configuration. +func NewUpdateReceiverValidator() *UpdateReceiverValidator { + return &UpdateReceiverValidator{ + Validator: NewValidator(), + } +} + +// ValidateReceiver validates if the infos present in the ReceiverRegistrationRequest are valids. +func (ur *UpdateReceiverValidator) ValidateReceiver(updateReceiverRequest *UpdateReceiverRequest) { + ur.Check(*updateReceiverRequest != UpdateReceiverRequest{}, "body", "request body is empty") + + if ur.HasErrors() { + return + } + + dateOfBirth := strings.TrimSpace(updateReceiverRequest.DateOfBirth) + pin := strings.TrimSpace(updateReceiverRequest.Pin) + nationalID := strings.TrimSpace(updateReceiverRequest.NationalID) + email := strings.TrimSpace(updateReceiverRequest.Email) + externalID := strings.TrimSpace(updateReceiverRequest.ExternalID) + + if dateOfBirth != "" { + _, err := time.Parse("2006-01-02", updateReceiverRequest.DateOfBirth) + ur.CheckError(err, "date_of_birth", "invalid date of birth format. Correct format: 1990-01-30") + } + + if updateReceiverRequest.Pin != "" { + // TODO: add new validation to PIN type. + ur.Check(pin != "", "pin", "invalid pin format") + } + + if updateReceiverRequest.NationalID != "" { + // TODO: add new validation to NationalID type. + ur.Check(nationalID != "", "national_id", "invalid national ID format") + } + + if updateReceiverRequest.Email != "" { + ur.Check(utils.ValidateEmail(email) == nil, "email", "invalid email format") + } + + if updateReceiverRequest.ExternalID != "" { + ur.Check(externalID != "", "external_id", "invalid external_id format") + } + + updateReceiverRequest.DateOfBirth = dateOfBirth + updateReceiverRequest.Pin = pin + updateReceiverRequest.NationalID = nationalID + updateReceiverRequest.Email = email + updateReceiverRequest.ExternalID = externalID +} diff --git a/internal/serve/validators/receiver_update_validator_test.go b/internal/serve/validators/receiver_update_validator_test.go new file mode 100644 index 000000000..b04d0de98 --- /dev/null +++ b/internal/serve/validators/receiver_update_validator_test.go @@ -0,0 +1,105 @@ +package validators + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_UpdateReceiverValidator_ValidateReceiver(t *testing.T) { + t.Run("Empty request", func(t *testing.T) { + validator := NewUpdateReceiverValidator() + + receiverInfo := UpdateReceiverRequest{} + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "request body is empty", validator.Errors["body"]) + }) + + t.Run("Invalid date of birth", func(t *testing.T) { + validator := NewUpdateReceiverValidator() + + receiverInfo := UpdateReceiverRequest{ + DateOfBirth: "invalid", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid date of birth format. Correct format: 1990-01-30", validator.Errors["date_of_birth"]) + }) + + t.Run("Invalid pin", func(t *testing.T) { + validator := NewUpdateReceiverValidator() + + receiverInfo := UpdateReceiverRequest{ + Pin: " ", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid pin format", validator.Errors["pin"]) + }) + + t.Run("Invalid national ID", func(t *testing.T) { + validator := NewUpdateReceiverValidator() + + receiverInfo := UpdateReceiverRequest{ + NationalID: " ", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid national ID format", validator.Errors["national_id"]) + }) + + t.Run("invalid email", func(t *testing.T) { + validator := NewUpdateReceiverValidator() + + receiverInfo := UpdateReceiverRequest{ + Email: "invalid", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid email format", validator.Errors["email"]) + + receiverInfo = UpdateReceiverRequest{ + Email: " ", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid email format", validator.Errors["email"]) + }) + + t.Run("invalid external ID", func(t *testing.T) { + validator := NewUpdateReceiverValidator() + + receiverInfo := UpdateReceiverRequest{ + ExternalID: " ", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 1, len(validator.Errors)) + assert.Equal(t, "invalid external_id format", validator.Errors["external_id"]) + }) + + t.Run("Valid receiver values", func(t *testing.T) { + validator := NewUpdateReceiverValidator() + + receiverInfo := UpdateReceiverRequest{ + DateOfBirth: "1999-01-01", + Pin: "123 ", + NationalID: " 12345CODE", + Email: "receiver@email.com", + ExternalID: "externalID", + } + validator.ValidateReceiver(&receiverInfo) + + assert.Equal(t, 0, len(validator.Errors)) + assert.Equal(t, "1999-01-01", receiverInfo.DateOfBirth) + assert.Equal(t, "123", receiverInfo.Pin) + assert.Equal(t, "12345CODE", receiverInfo.NationalID) + }) +} diff --git a/internal/serve/validators/validator.go b/internal/serve/validators/validator.go new file mode 100644 index 000000000..d4282fee6 --- /dev/null +++ b/internal/serve/validators/validator.go @@ -0,0 +1,28 @@ +package validators + +type Validator struct { + Errors map[string]interface{} +} + +func NewValidator() *Validator { + return &Validator{Errors: make(map[string]interface{})} +} + +func (v *Validator) HasErrors() bool { + return len(v.Errors) > 0 +} + +func (v *Validator) Check(ok bool, key, message string) { + if !ok { + v.addError(key, message) + } +} + +// CheckError is a convenience method for checking if an error is nil +func (v *Validator) CheckError(err error, key, message string) { + v.Check(err == nil, key, message) +} + +func (v *Validator) addError(key, message string) { + v.Errors[key] = message +} diff --git a/internal/serve/validators/validator_test.go b/internal/serve/validators/validator_test.go new file mode 100644 index 000000000..537e5efb2 --- /dev/null +++ b/internal/serve/validators/validator_test.go @@ -0,0 +1,41 @@ +package validators + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewValidator(t *testing.T) { + validator := NewValidator() + assert.NotNil(t, validator) + assert.NotNil(t, validator.Errors) +} + +func Test_Check(t *testing.T) { + validator := NewValidator() + validator.Check(true, "key", "error message") + + assert.Emptyf(t, validator.Errors, "validator should not have errors") + + validator.Check(false, "key", "error message") + assert.NotEmpty(t, validator.Errors) + assert.Equal(t, validator.Errors["key"], "error message") +} + +func Test_HasErrors(t *testing.T) { + validator := NewValidator() + assert.False(t, validator.HasErrors()) + + validator.Check(false, "key", "error message") + assert.True(t, validator.HasErrors()) +} + +func Test_addError(t *testing.T) { + validator := NewValidator() + validator.addError("key", "error message") + validator.addError("key2", "error message 2") + assert.Equal(t, len(validator.Errors), 2) + assert.Equal(t, validator.Errors["key"], "error message") + assert.Equal(t, validator.Errors["key2"], "error message 2") +} diff --git a/internal/services/disbursement_service.go b/internal/services/disbursement_service.go new file mode 100644 index 000000000..ac160167e --- /dev/null +++ b/internal/services/disbursement_service.go @@ -0,0 +1,195 @@ +package services + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type DisbursementService struct { + models *data.Models + dbConnectionPool db.DBConnectionPool + authManager auth.AuthManager +} + +var ( + ErrDisbursementNotFound = errors.New("disbursement not found") + ErrDisbursementNotReadyToStart = errors.New("disbursement is not ready to be started") + ErrDisbursementNotReadyToPause = errors.New("disbursement is not ready to be paused") + + ErrDisbursementStatusCantBeChanged = errors.New("disbursement status can't be changed to the requested status") + ErrDisbursementStartedByCreator = errors.New("disbursement can't be started by its creator") +) + +// NewDisbursementService creates a new DisbursementService +func NewDisbursementService(models *data.Models, dbConnectionPool db.DBConnectionPool, authManager auth.AuthManager) *DisbursementService { + return &DisbursementService{ + models: models, + dbConnectionPool: dbConnectionPool, + authManager: authManager, + } +} + +func (s *DisbursementService) GetDisbursementsWithCount(ctx context.Context, queryParams *data.QueryParams) (*ResultWithTotal, error) { + return db.RunInTransactionWithResult(ctx, + s.dbConnectionPool, + &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true}, + func(dbTx db.DBTransaction) (*ResultWithTotal, error) { + totalDisbursements, err := s.models.Disbursements.Count(ctx, dbTx, queryParams) + if err != nil { + return nil, fmt.Errorf("error counting disbursements: %w", err) + } + + var disbursements []*data.Disbursement + if totalDisbursements != 0 { + disbursements, err = s.models.Disbursements.GetAll(ctx, dbTx, queryParams) + if err != nil { + return nil, fmt.Errorf("error retrieving disbursements: %w", err) + } + } + + return NewResultWithTotal(totalDisbursements, disbursements), nil + }) +} + +func (s *DisbursementService) GetDisbursementReceiversWithCount(ctx context.Context, disbursementID string, queryParams *data.QueryParams) (*ResultWithTotal, error) { + return db.RunInTransactionWithResult(ctx, + s.dbConnectionPool, + &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true}, + func(dbTx db.DBTransaction) (*ResultWithTotal, error) { + _, err := s.models.Disbursements.Get(ctx, dbTx, disbursementID) + if err != nil { + if errors.Is(err, data.ErrRecordNotFound) { + return nil, ErrDisbursementNotFound + } else { + return nil, fmt.Errorf("error getting disbursement with id %s: %w", disbursementID, err) + } + } + + totalReceivers, err := s.models.DisbursementReceivers.Count(ctx, dbTx, disbursementID) + if err != nil { + return nil, fmt.Errorf("error counting disbursement receivers for disbursement with id %s: %w", disbursementID, err) + } + + receivers := []*data.DisbursementReceiver{} + if totalReceivers != 0 { + receivers, err = s.models.DisbursementReceivers.GetAll(ctx, dbTx, queryParams, disbursementID) + if err != nil { + return nil, fmt.Errorf("error retrieving disbursement receivers for disbursement with id %s: %w", disbursementID, err) + } + } + + return NewResultWithTotal(totalReceivers, receivers), nil + }) +} + +// StartDisbursement starts a disbursement and all its payments and receivers wallets +func (s *DisbursementService) StartDisbursement(ctx context.Context, disbursementID string) error { + return db.RunInTransaction(ctx, s.dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + disbursement, err := s.models.Disbursements.Get(ctx, dbTx, disbursementID) + if err != nil { + if errors.Is(err, data.ErrRecordNotFound) { + return ErrDisbursementNotFound + } else { + return fmt.Errorf("error getting disbursement with id %s: %w", disbursementID, err) + } + } + + // 1. Verify Transition is Possible + err = disbursement.Status.TransitionTo(data.StartedDisbursementStatus) + if err != nil { + return ErrDisbursementNotReadyToStart + } + + // 2. Check if approval Workflow is enabled for this organization + organization, err := s.models.Organizations.Get(ctx) + if err != nil { + return fmt.Errorf("error getting organization: %w", err) + } + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + return fmt.Errorf("error getting token from context") + } + user, err := s.authManager.GetUser(ctx, token) + if err != nil { + return fmt.Errorf("error getting user from token: %w", err) + } + if organization.IsApprovalRequired { + // check that the user starting the disbursement isn't the same as the one who created it + for _, sh := range disbursement.StatusHistory { + if sh.UserID == user.ID && (sh.Status == data.DraftDisbursementStatus || sh.Status == data.ReadyDisbursementStatus) { + return ErrDisbursementStartedByCreator + } + } + } + + // 3. Update all correct payment status to `ready` + err = s.models.Payment.UpdateStatusByDisbursementID(ctx, dbTx, disbursementID, data.ReadyPaymentStatus) + if err != nil { + return fmt.Errorf("error updating payment status to ready for disbursement with id %s: %w", disbursementID, err) + } + + // 4. Update all receiver_wallets from `draft` to `ready` + err = s.models.ReceiverWallet.UpdateStatusByDisbursementID(ctx, dbTx, disbursementID, data.DraftReceiversWalletStatus, data.ReadyReceiversWalletStatus) + if err != nil { + return fmt.Errorf("error updating receiver wallet status to ready for disbursement with id %s: %w", disbursementID, err) + } + + // 5. Update disbursement status to `started` + err = s.models.Disbursements.UpdateStatus(ctx, dbTx, user.ID, disbursementID, data.StartedDisbursementStatus) + if err != nil { + return fmt.Errorf("error updating disbursement status to started for disbursement with id %s: %w", disbursementID, err) + } + + return nil + }) +} + +// PauseDisbursement pauses a disbursement and all its payments +func (s *DisbursementService) PauseDisbursement(ctx context.Context, disbursementID string) error { + return db.RunInTransaction(ctx, s.dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + disbursement, err := s.models.Disbursements.Get(ctx, dbTx, disbursementID) + if err != nil { + if errors.Is(err, data.ErrRecordNotFound) { + return ErrDisbursementNotFound + } else { + return fmt.Errorf("error getting disbursement with id %s: %w", disbursementID, err) + } + } + + // 1. Verify Transition is Possible + err = disbursement.Status.TransitionTo(data.PausedDisbursementStatus) + if err != nil { + return ErrDisbursementNotReadyToPause + } + + // 2. Update all correct payment status to `paused` + err = s.models.Payment.UpdateStatusByDisbursementID(ctx, dbTx, disbursementID, data.PausedPaymentStatus) + if err != nil { + return fmt.Errorf("error updating payment status to paused for disbursement with id %s: %w", disbursementID, err) + } + + // 3. Update disbursement status to `paused` + token, ok := ctx.Value(middleware.TokenContextKey).(string) + if !ok { + return fmt.Errorf("error getting token from context") + } + user, err := s.authManager.GetUser(ctx, token) + if err != nil { + return fmt.Errorf("error getting user from token: %w", err) + } + err = s.models.Disbursements.UpdateStatus(ctx, dbTx, user.ID, disbursementID, data.PausedDisbursementStatus) + if err != nil { + return fmt.Errorf("error updating disbursement status to started for disbursement with id %s: %w", disbursementID, err) + } + + return nil + }) +} diff --git a/internal/services/disbursement_service_test.go b/internal/services/disbursement_service_test.go new file mode 100644 index 000000000..658fb6ed6 --- /dev/null +++ b/internal/services/disbursement_service_test.go @@ -0,0 +1,555 @@ +package services + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/middleware" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/require" +) + +func Test_DisbursementService_GetDisbursementsWithCount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + service := NewDisbursementService(models, models.DBConnectionPool, nil) + + ctx := context.Background() + t.Run("disbursements list empty", func(t *testing.T) { + resultWithTotal, err := service.GetDisbursementsWithCount(ctx, &data.QueryParams{}) + require.NoError(t, err) + require.Equal(t, 0, resultWithTotal.Total) + result, ok := resultWithTotal.Result.([]*data.Disbursement) + require.True(t, ok) + require.Equal(t, 0, len(result)) + }) + + t.Run("get disbursements successfully", func(t *testing.T) { + // create disbursements + d1 := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, models.Disbursements, &data.Disbursement{Name: "d1"}) + d2 := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, models.Disbursements, &data.Disbursement{Name: "d2"}) + + resultWithTotal, err := service.GetDisbursementsWithCount(ctx, &data.QueryParams{SortOrder: "asc", SortBy: "name"}) + require.NoError(t, err) + require.Equal(t, 2, resultWithTotal.Total) + result, ok := resultWithTotal.Result.([]*data.Disbursement) + require.True(t, ok) + require.Equal(t, 2, len(result)) + require.Equal(t, d1.ID, result[0].ID) + require.Equal(t, d2.ID, result[1].ID) + }) +} + +func Test_DisbursementService_GetDisbursementReceiversWithCount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + service := NewDisbursementService(models, models.DBConnectionPool, nil) + disbursement := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, models.Disbursements, &data.Disbursement{}) + + ctx := context.Background() + t.Run("disbursements not found", func(t *testing.T) { + resultWithTotal, err := service.GetDisbursementReceiversWithCount(ctx, "wrong-id", &data.QueryParams{}) + require.ErrorIs(t, err, ErrDisbursementNotFound) + require.Nil(t, resultWithTotal) + }) + + t.Run("disbursements receivers list empty", func(t *testing.T) { + resultWithTotal, err := service.GetDisbursementReceiversWithCount(ctx, disbursement.ID, &data.QueryParams{}) + require.NoError(t, err) + require.Equal(t, 0, resultWithTotal.Total) + result, ok := resultWithTotal.Result.([]*data.DisbursementReceiver) + require.True(t, ok) + require.Equal(t, 0, len(result)) + }) + + t.Run("get disbursement receivers successfully", func(t *testing.T) { + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + rwDraft1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, disbursement.Wallet.ID, data.DraftReceiversWalletStatus) + rwDraft2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, disbursement.Wallet.ID, data.DraftReceiversWalletStatus) + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwDraft1, + Disbursement: disbursement, + Asset: *disbursement.Asset, + Amount: "100", + Status: data.DraftPaymentStatus, + }) + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwDraft2, + Disbursement: disbursement, + Asset: *disbursement.Asset, + Amount: "200", + Status: data.DraftPaymentStatus, + }) + + resultWithTotal, err := service.GetDisbursementReceiversWithCount(ctx, disbursement.ID, &data.QueryParams{}) + require.NoError(t, err) + require.Equal(t, 2, resultWithTotal.Total) + result, ok := resultWithTotal.Result.([]*data.DisbursementReceiver) + require.True(t, ok) + require.Equal(t, 2, len(result)) + }) +} + +func Test_DisbursementService_StartDisbursement(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + token := "token" + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + authManagerMock := &auth.AuthManagerMock{} + service := NewDisbursementService(models, models.DBConnectionPool, authManagerMock) + + // create fixtures + wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool) + asset := data.GetAssetFixture(t, ctx, dbConnectionPool, data.FixtureAssetUSDC) + country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUKR) + + // create disbursements + draftDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "draft disbursement", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + readyDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "ready disbursement", + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + // create disbursement receivers + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver3 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver4 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + + receiverIds := []string{receiver1.ID, receiver2.ID, receiver3.ID, receiver4.ID} + + rwDraft1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.DraftReceiversWalletStatus) + rwDraft2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.DraftReceiversWalletStatus) + rwReady := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver3.ID, wallet.ID, data.ReadyReceiversWalletStatus) + rwRegistered := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver4.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwDraft1, + Disbursement: readyDisbursement, + Asset: *asset, + Amount: "100", + Status: data.DraftPaymentStatus, + }) + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwDraft2, + Disbursement: readyDisbursement, + Asset: *asset, + Amount: "200", + Status: data.DraftPaymentStatus, + }) + payment3 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwReady, + Disbursement: readyDisbursement, + Asset: *asset, + Amount: "300", + Status: data.DraftPaymentStatus, + }) + payment4 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwRegistered, + Disbursement: readyDisbursement, + Asset: *asset, + Amount: "400", + Status: data.DraftPaymentStatus, + }) + + payments := []*data.Payment{payment1, payment2, payment3, payment4} + + t.Run("disbursement doesn't exist", func(t *testing.T) { + id := "5e1f1c7f5b6c9c0001c1b1b1" + + err = service.StartDisbursement(context.Background(), id) + require.ErrorIs(t, err, ErrDisbursementNotFound) + }) + + t.Run("disbursement not ready to start", func(t *testing.T) { + err = service.StartDisbursement(context.Background(), draftDisbursement.ID) + require.ErrorIs(t, err, ErrDisbursementNotReadyToStart) + }) + + t.Run("disbursement can't be started by its creator", func(t *testing.T) { + userID := "9ae68f09-cad9-4311-9758-4ff59d2e9e6d" + statusHistory := []data.DisbursementStatusHistoryEntry{ + { + Status: data.DraftDisbursementStatus, + UserID: userID, + }, + { + Status: data.ReadyDisbursementStatus, + UserID: userID, + }, + } + disbursement := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement #1", + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + StatusHistory: statusHistory, + }) + + user := &auth.User{ + ID: userID, + Email: "email@email.com", + } + + authManagerMock. + On("GetUser", ctx, token). + Return(user, nil). + Once() + + // Enable approval workflow for org. + isApprovalRequired := true + err = models.Organizations.Update(ctx, &data.OrganizationUpdate{IsApprovalRequired: &isApprovalRequired}) + require.NoError(t, err) + + err = service.StartDisbursement(ctx, disbursement.ID) + require.ErrorIs(t, err, ErrDisbursementStartedByCreator) + + // rollback changes + isApprovalRequired = false + err = models.Organizations.Update(ctx, &data.OrganizationUpdate{IsApprovalRequired: &isApprovalRequired}) + require.NoError(t, err) + }) + + t.Run("disbursement started with approval workflow", func(t *testing.T) { + userID := "9ae68f09-cad9-4311-9758-4ff59d2e9e6d" + statusHistory := []data.DisbursementStatusHistoryEntry{ + { + Status: data.DraftDisbursementStatus, + UserID: userID, + }, + { + Status: data.ReadyDisbursementStatus, + UserID: userID, + }, + } + disbursement := data.CreateDisbursementFixture(t, context.Background(), dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement #2", + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + StatusHistory: statusHistory, + }) + + user := &auth.User{ + ID: "another user id", + Email: "email@email.com", + } + + authManagerMock. + On("GetUser", ctx, token). + Return(user, nil). + Once() + + // Enable approval workflow for org. + isApprovalRequired := true + err = models.Organizations.Update(ctx, &data.OrganizationUpdate{IsApprovalRequired: &isApprovalRequired}) + require.NoError(t, err) + + err = service.StartDisbursement(ctx, disbursement.ID) + require.NoError(t, err) + + // check disbursement status + disbursement, err = models.Disbursements.Get(context.Background(), models.DBConnectionPool, disbursement.ID) + require.NoError(t, err) + require.Equal(t, data.StartedDisbursementStatus, disbursement.Status) + + // rollback changes + isApprovalRequired = false + err = models.Organizations.Update(ctx, &data.OrganizationUpdate{IsApprovalRequired: &isApprovalRequired}) + require.NoError(t, err) + }) + + t.Run("disbursement started", func(t *testing.T) { + user := &auth.User{ + ID: "user-id", + Email: "email@email.com", + } + + authManagerMock. + On("GetUser", ctx, token). + Return(user, nil). + Once() + + err = service.StartDisbursement(ctx, readyDisbursement.ID) + require.NoError(t, err) + + // check disbursement status + disbursement, err := models.Disbursements.Get(context.Background(), models.DBConnectionPool, readyDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.StartedDisbursementStatus, disbursement.Status) + + // check disbursement history + require.Equal(t, disbursement.StatusHistory[1].UserID, user.ID) + + // check receivers wallets status + receiverWallets, err := models.ReceiverWallet.GetByReceiverIDsAndWalletID(ctx, models.DBConnectionPool, receiverIds, wallet.ID) + require.NoError(t, err) + require.Equal(t, 4, len(receiverWallets)) + rwExpectedStatuses := map[string]data.ReceiversWalletStatus{ + rwDraft1.ID: data.ReadyReceiversWalletStatus, + rwDraft2.ID: data.ReadyReceiversWalletStatus, + rwReady.ID: data.ReadyReceiversWalletStatus, + rwRegistered.ID: data.RegisteredReceiversWalletStatus, + } + for _, rw := range receiverWallets { + require.Equal(t, rwExpectedStatuses[rw.ID], rw.Status) + } + + // check payments status + for _, p := range payments { + payment, err := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, err) + require.Equal(t, data.ReadyPaymentStatus, payment.Status) + } + }) + + authManagerMock.AssertExpectations(t) +} + +func Test_DisbursementService_PauseDisbursement(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + models, outerErr := data.NewModels(dbConnectionPool) + require.NoError(t, outerErr) + + token := "token" + ctx := context.WithValue(context.Background(), middleware.TokenContextKey, token) + + user := &auth.User{ + ID: "user-id", + Email: "email@email.com", + } + authManagerMock := &auth.AuthManagerMock{} + authManagerMock. + On("GetUser", ctx, token). + Return(user, nil) + + service := NewDisbursementService(models, models.DBConnectionPool, authManagerMock) + + // create fixtures + wallet := data.CreateDefaultWalletFixture(t, ctx, dbConnectionPool) + asset := data.GetAssetFixture(t, ctx, dbConnectionPool, data.FixtureAssetUSDC) + country := data.GetCountryFixture(t, ctx, dbConnectionPool, data.FixtureCountryUSA) + + // create disbursements + readyDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "ready disbursement", + Status: data.ReadyDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "started disbursement", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + // create disbursement receivers + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver3 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver4 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + + rwRegistered1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rwRegistered2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rwRegistered3 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver3.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rwRegistered4 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver4.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + paymentPending1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwRegistered1, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "100", + Status: data.PendingPaymentStatus, + }) + paymentPending2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwRegistered2, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "200", + Status: data.PendingPaymentStatus, + }) + paymentReady1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwRegistered3, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "300", + Status: data.ReadyPaymentStatus, + }) + paymentReady2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwRegistered4, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "400", + Status: data.ReadyPaymentStatus, + }) + + t.Run("disbursement doesn't exist", func(t *testing.T) { + id := "5e1f1c7f5b6c9c0001c1b1b1" + + err := service.PauseDisbursement(ctx, id) + require.ErrorIs(t, err, ErrDisbursementNotFound) + }) + + t.Run("disbursement not ready to pause", func(t *testing.T) { + err := service.PauseDisbursement(ctx, readyDisbursement.ID) + require.ErrorIs(t, err, ErrDisbursementNotReadyToPause) + }) + + t.Run("disbursement paused", func(t *testing.T) { + err := service.PauseDisbursement(ctx, startedDisbursement.ID) + require.NoError(t, err) + + // check disbursement status + disbursement, err := models.Disbursements.Get(ctx, models.DBConnectionPool, startedDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.PausedDisbursementStatus, disbursement.Status) + + // check pending payments are still pending. + for _, p := range []*data.Payment{paymentPending1, paymentPending2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.PendingPaymentStatus, payment.Status) + } + + // check ready payments are paused. + for _, p := range []*data.Payment{paymentReady1, paymentReady2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.PausedPaymentStatus, payment.Status) + } + + // change the disbursement back to started + err = service.StartDisbursement(ctx, startedDisbursement.ID) + require.NoError(t, err) + + // check disbursement is started again + disbursement, err = models.Disbursements.Get(ctx, models.DBConnectionPool, startedDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.StartedDisbursementStatus, disbursement.Status) + }) + + t.Run("start -> pause -> start -> pause", func(t *testing.T) { + // 1. Pause Disbursement + err := service.PauseDisbursement(ctx, startedDisbursement.ID) + require.NoError(t, err) + + // check disbursement is paused + disbursement, err := models.Disbursements.Get(ctx, models.DBConnectionPool, startedDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.PausedDisbursementStatus, disbursement.Status) + + // check pending payments are still pending. + for _, p := range []*data.Payment{paymentPending1, paymentPending2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.PendingPaymentStatus, payment.Status) + } + + // check ready payments are paused. + for _, p := range []*data.Payment{paymentReady1, paymentReady2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.PausedPaymentStatus, payment.Status) + } + + // 2. Start disbursement again + err = service.StartDisbursement(ctx, startedDisbursement.ID) + require.NoError(t, err) + + // check disbursement is started again + disbursement, err = models.Disbursements.Get(ctx, models.DBConnectionPool, startedDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.StartedDisbursementStatus, disbursement.Status) + + // check pending payments are still pending. + for _, p := range []*data.Payment{paymentPending1, paymentPending2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.PendingPaymentStatus, payment.Status) + } + + // check paused payments are back to ready. + for _, p := range []*data.Payment{paymentReady1, paymentReady2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.ReadyPaymentStatus, payment.Status) + } + + // 3. Pause disbursement again + err = service.PauseDisbursement(ctx, startedDisbursement.ID) + require.NoError(t, err) + + // check disbursement is paused + disbursement, err = models.Disbursements.Get(ctx, models.DBConnectionPool, startedDisbursement.ID) + require.NoError(t, err) + require.Equal(t, data.PausedDisbursementStatus, disbursement.Status) + + // check pending payments are still pending. + for _, p := range []*data.Payment{paymentPending1, paymentPending2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.PendingPaymentStatus, payment.Status) + } + + // check ready payments are paused again. + for _, p := range []*data.Payment{paymentReady1, paymentReady2} { + payment, innerErr := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, innerErr) + require.Equal(t, data.PausedPaymentStatus, payment.Status) + } + }) + + authManagerMock.AssertExpectations(t) +} diff --git a/internal/services/payment_service.go b/internal/services/payment_service.go new file mode 100644 index 000000000..0d60f81da --- /dev/null +++ b/internal/services/payment_service.go @@ -0,0 +1,50 @@ +package services + +import ( + "context" + "fmt" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type PaymentService struct { + models *data.Models + dbConnectionPool db.DBConnectionPool +} + +// NewPaymentService creates a new PaymentService +func NewPaymentService(models *data.Models, dbConnectionPool db.DBConnectionPool) *PaymentService { + return &PaymentService{ + models: models, + dbConnectionPool: dbConnectionPool, + } +} + +type PaymentsPaginatedResponse struct { + TotalPayments int + Payments []data.Payment +} + +// GetPaymentsWithCount creates a new DB transaction to get payments and total payments filtered by query params. +func (s *PaymentService) GetPaymentsWithCount(ctx context.Context, queryParams *data.QueryParams) (*PaymentsPaginatedResponse, error) { + return db.RunInTransactionWithResult(ctx, s.dbConnectionPool, nil, func(dbTx db.DBTransaction) (response *PaymentsPaginatedResponse, innerErr error) { + totalPayments, innerErr := s.models.Payment.Count(ctx, queryParams, dbTx) + if innerErr != nil { + return nil, fmt.Errorf("error counting payments: %w", innerErr) + } + + var payments []data.Payment + if totalPayments != 0 { + payments, innerErr = s.models.Payment.GetAll(ctx, queryParams, dbTx) + if innerErr != nil { + return nil, fmt.Errorf("error querying payments: %w", innerErr) + } + } + + return &PaymentsPaginatedResponse{ + TotalPayments: totalPayments, + Payments: payments, + }, nil + }) +} diff --git a/internal/services/payment_service_test.go b/internal/services/payment_service_test.go new file mode 100644 index 000000000..b2dabf7a5 --- /dev/null +++ b/internal/services/payment_service_test.go @@ -0,0 +1,106 @@ +package services + +import ( + "context" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetPaymentsWithCount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + service := NewPaymentService(models, dbConnectionPool) + + t.Run("0 payments created", func(t *testing.T) { + response, err := service.GetPaymentsWithCount(ctx, &data.QueryParams{}) + require.NoError(t, err) + + assert.Equal(t, response.TotalPayments, 0) + assert.Equal(t, response.Payments, []data.Payment(nil)) + }) + + t.Run("error invalid payment status", func(t *testing.T) { + _, err := service.GetPaymentsWithCount(ctx, &data.QueryParams{ + Filters: map[data.FilterKey]interface{}{ + data.FilterKeyStatus: "INVALID", + }, + }) + require.EqualError(t, err, `running atomic function in RunInTransactionWithResult: error counting payments: error counting payments: pq: invalid input value for enum payment_status: "INVALID"`) + }) + + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.DraftReceiversWalletStatus) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.DraftDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.DraftPaymentStatus, + StatusHistory: []data.PaymentStatusHistoryEntry{ + { + Status: data.DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + t.Run("return payment", func(t *testing.T) { + response, err := service.GetPaymentsWithCount(ctx, &data.QueryParams{}) + require.NoError(t, err) + + assert.Equal(t, response.TotalPayments, 1) + assert.Equal(t, response.Payments, []data.Payment{*payment}) + }) + + t.Run("return multiple payments", func(t *testing.T) { + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "50", + Status: data.DraftPaymentStatus, + StatusHistory: []data.PaymentStatusHistoryEntry{ + { + Status: data.DraftPaymentStatus, + StatusMessage: "", + Timestamp: time.Now(), + }, + }, + Disbursement: disbursement, + Asset: *asset, + ReceiverWallet: receiverWallet, + }) + + response, err := service.GetPaymentsWithCount(ctx, &data.QueryParams{}) + require.NoError(t, err) + + assert.Equal(t, response.TotalPayments, 2) + assert.Equal(t, response.Payments, []data.Payment{*payment2, *payment}) + }) +} diff --git a/internal/services/result_with_total.go b/internal/services/result_with_total.go new file mode 100644 index 000000000..b97e91dff --- /dev/null +++ b/internal/services/result_with_total.go @@ -0,0 +1,13 @@ +package services + +type ResultWithTotal struct { + Total int + Result interface{} +} + +func NewResultWithTotal(total int, result interface{}) *ResultWithTotal { + return &ResultWithTotal{ + Total: total, + Result: result, + } +} diff --git a/internal/services/result_with_total_test.go b/internal/services/result_with_total_test.go new file mode 100644 index 000000000..99bf1c577 --- /dev/null +++ b/internal/services/result_with_total_test.go @@ -0,0 +1,17 @@ +package services + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_ResultWithTotal_NewResultWithTotal(t *testing.T) { + total := 10 + result := []string{"apple", "banana", "cherry"} + + resultWithTotal := NewResultWithTotal(total, result) + + require.Equal(t, total, resultWithTotal.Total) + require.Equal(t, result, resultWithTotal.Result) +} diff --git a/internal/services/send_payments_service.go b/internal/services/send_payments_service.go new file mode 100644 index 000000000..addea76fb --- /dev/null +++ b/internal/services/send_payments_service.go @@ -0,0 +1,146 @@ +package services + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + txSubStore "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type SendPaymentsServiceInterface interface { + SendBatchPayments(ctx context.Context, batchSize int) error +} + +type SendPaymentsService struct { + sdpModels *data.Models + tssModel *txSubStore.TransactionModel +} + +// SendBatchPayments sends payments in batches +func (s SendPaymentsService) SendBatchPayments(ctx context.Context, batchSize int) error { + err := db.RunInTransaction(ctx, s.sdpModels.DBConnectionPool, nil, func(dbTx db.DBTransaction) error { + return s.sendBatchPayments(ctx, dbTx, batchSize) + }) + if err != nil { + return fmt.Errorf("error sending payments: %w", err) + } + + return nil +} + +// sendBatchPayments sends payments in batches in a transaction +func (s SendPaymentsService) sendBatchPayments(ctx context.Context, dbTx db.DBTransaction, batchSize int) error { + // 1. Get payments that are ready to be sent. This will lock the rows. + // Payments Ready to be sent means: + // a. Payment is in `READY` status + // b. Receiver Wallet is in `REGISTERED` status + // c. Disbursement is in `STARTED` status + payments, err := s.sdpModels.Payment.GetBatchForUpdate(ctx, dbTx, batchSize) + if err != nil { + return fmt.Errorf("error getting payments ready to be sent: %w", err) + } + + var transactions []txSubStore.Transaction + var failedPayments []*data.Payment + var pendingPayments []*data.Payment + for _, payment := range payments { + // 2. Validate that payments are ready to be sent + if errValidation := validatePaymentReadyForSending(payment); errValidation != nil { + // if payment is not ready for sending, we will mark it as failed later. + failedPayments = append(failedPayments, payment) + log.Ctx(ctx).Errorf("Payment %s is not ready for sending. Error:%s", payment.ID, errValidation.Error()) + continue + } + + // TODO: change TSS to use string amount [SDP-483] + amount, parseErr := strconv.ParseFloat(payment.Amount, 64) + if parseErr != nil { + return fmt.Errorf("error parsing payment amount %s for payment %s: %w", payment.Amount, payment.ID, parseErr) + } + transaction := txSubStore.Transaction{ + ExternalID: payment.ID, + AssetCode: payment.Asset.Code, + AssetIssuer: payment.Asset.Issuer, + Amount: amount, + Destination: payment.ReceiverWallet.StellarAddress, + } + transactions = append(transactions, transaction) + pendingPayments = append(pendingPayments, payment) + } + + // 3. Persist data in Transactions table + _, err = s.tssModel.BulkInsert(ctx, dbTx, transactions) + if err != nil { + return fmt.Errorf("error inserting transactions: %w", err) + } + // 4. Update payment statuses to `Pending` + err = s.sdpModels.Payment.UpdateStatuses(ctx, dbTx, pendingPayments, data.PendingPaymentStatus) + if err != nil { + return fmt.Errorf("error updating payment statuses to Pending: %w", err) + } + + // 5. Update failed payments statuses to `Failed` + if len(failedPayments) != 0 { + err = s.sdpModels.Payment.UpdateStatuses(ctx, dbTx, failedPayments, data.FailedPaymentStatus) + if err != nil { + return fmt.Errorf("error updating payment statuses to Failed: %w", err) + } + } + return nil +} + +// ValidateReadyForSending validate that payment is ready for sending +// 1. Check Statuses of Payment, Receiver Wallet, and Disbursement +// 2. Check required fields are not empty. +func validatePaymentReadyForSending(p *data.Payment) error { + // check statuses + if p.Status != data.ReadyPaymentStatus { + return fmt.Errorf("payment %s is not in %s state", p.ID, data.ReadyPaymentStatus) + } + if p.ReceiverWallet.Status != data.RegisteredReceiversWalletStatus { + return fmt.Errorf("receiver wallet %s for payment %s is not in %s state", p.ReceiverWallet.ID, p.ID, data.RegisteredReceiversWalletStatus) + } + if p.Disbursement.Status != data.StartedDisbursementStatus { + return fmt.Errorf("disbursement %s for payment %s is not in %s state", p.Disbursement.ID, p.ID, data.StartedDisbursementStatus) + } + + // verify that transaction required fields are not empty + // 1. payment.ID is used as transaction.ExternalID + if strings.TrimSpace(p.ID) == "" { + return fmt.Errorf("payment ID is empty for Payment") + } + // 2. payment.asset.Code is used as transaction.AssetCode + if strings.TrimSpace(p.Asset.Code) == "" { + return fmt.Errorf("payment asset code is empty for payment %s", p.ID) + } + // 3. payment.asset.Issuer is used as transaction.AssetIssuer + if strings.TrimSpace(p.Asset.Issuer) == "" { + return fmt.Errorf("payment asset issuer is empty for payment %s", p.ID) + } + // 4. payment.Amount is used as transaction.Amount + if err := utils.ValidateAmount(p.Amount); err != nil { + return fmt.Errorf("payment amount is invalid for payment %s", p.ID) + } + // 5. payment.ReceiverWallet.StellarAddress is used as transaction.Destination + if strings.TrimSpace(p.ReceiverWallet.StellarAddress) == "" { + return fmt.Errorf("payment receiver wallet stellar address is empty for payment %s", p.ID) + } + + return nil +} + +func NewSendPaymentsService(models *data.Models) *SendPaymentsService { + return &SendPaymentsService{ + sdpModels: models, + tssModel: txSubStore.NewTransactionModel(models.DBConnectionPool), + } +} + +// Making sure that ServerService implements ServerServiceInterface +var _ SendPaymentsServiceInterface = (*SendPaymentsService)(nil) diff --git a/internal/services/send_payments_service_test.go b/internal/services/send_payments_service_test.go new file mode 100644 index 000000000..c018b6ed0 --- /dev/null +++ b/internal/services/send_payments_service_test.go @@ -0,0 +1,396 @@ +package services + +import ( + "context" + "strconv" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + txSubStore "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_SendPaymentsService_SendBatchPayments(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + tssModel := txSubStore.NewTransactionModel(models.DBConnectionPool) + + service := NewSendPaymentsService(models) + ctx := context.Background() + + // create fixtures + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, + "My Wallet", + "https://www.wallet.com", + "www.wallet.com", + "wallet1://") + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, + "USDC", + "GDUCE34WW5Z34GMCEPURYANUCUP47J6NORJLKC6GJNMDLN4ZI4PMI2MG") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, + "FRA", + "France") + + // create disbursements + startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "ready disbursement", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + // create disbursement receivers + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver3 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver4 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + + rw1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rw2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rw3 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver3.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rwReady := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver4.ID, wallet.ID, data.ReadyReceiversWalletStatus) + + payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rw1, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "100", + Status: data.ReadyPaymentStatus, + }) + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rw2, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "200", + Status: data.ReadyPaymentStatus, + }) + payment3 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rw3, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "300", + Status: data.ReadyPaymentStatus, + }) + payment4 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + ReceiverWallet: rwReady, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "400", + Status: data.ReadyPaymentStatus, + }) + + t.Run("send payments", func(t *testing.T) { + err = service.SendBatchPayments(ctx, 5) + require.NoError(t, err) + + // payments that can be sent + for _, p := range []*data.Payment{payment1, payment2, payment3} { + payment, err := models.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, err) + require.Equal(t, data.PendingPaymentStatus, payment.Status) + } + + // payments that can't be sent (rw status is not REGISTERED) + payment, err := models.Payment.Get(ctx, payment4.ID, dbConnectionPool) + require.NoError(t, err) + require.Equal(t, data.ReadyPaymentStatus, payment.Status) + + // validate transactions + transactions, err := tssModel.GetAllByPaymentIDs(ctx, []string{payment1.ID, payment2.ID, payment3.ID, payment4.ID}) + require.NoError(t, err) + require.Len(t, transactions, 3) + + expectedPayments := map[string]*data.Payment{ + payment1.ID: payment1, + payment2.ID: payment2, + payment3.ID: payment3, + } + + for _, tx := range transactions { + require.Equal(t, txSubStore.TransactionStatusPending, tx.Status) + require.Equal(t, expectedPayments[tx.ExternalID].Asset.Code, tx.AssetCode) + require.Equal(t, expectedPayments[tx.ExternalID].Asset.Issuer, tx.AssetIssuer) + require.Equal(t, expectedPayments[tx.ExternalID].Amount, strconv.FormatFloat(tx.Amount, 'f', 7, 32)) + require.Equal(t, expectedPayments[tx.ExternalID].ReceiverWallet.StellarAddress, tx.Destination) + require.Equal(t, expectedPayments[tx.ExternalID].ID, tx.ExternalID) + } + }) +} + +func Test_SendPaymentsService_ValidatePaymentReadyForSending(t *testing.T) { + testCases := []struct { + name string + payment *data.Payment + expectedError string + }{ + { + name: "valid payment", + payment: &data.Payment{ + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + Status: data.RegisteredReceiversWalletStatus, + StellarAddress: "destination_1", + }, + Disbursement: &data.Disbursement{ + Status: data.StartedDisbursementStatus, + }, + ID: "1", + Asset: data.Asset{ + Code: "USDC", + Issuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVN", + }, + Amount: "100.0", + }, + expectedError: "", + }, + { + name: "invalid payment status", + payment: &data.Payment{ + ID: "123", + Status: data.PendingPaymentStatus, + }, + expectedError: "payment 123 is not in READY state", + }, + { + name: "invalid receiver wallet status", + payment: &data.Payment{ + ID: "123", + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + ID: "321", + Status: data.ReadyReceiversWalletStatus, + }, + }, + expectedError: "receiver wallet 321 for payment 123 is not in REGISTERED state", + }, + { + name: "invalid disbursement status", + payment: &data.Payment{ + ID: "123", + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + Status: data.RegisteredReceiversWalletStatus, + }, + Disbursement: &data.Disbursement{ + ID: "321", + Status: data.ReadyDisbursementStatus, + }, + }, + expectedError: "disbursement 321 for payment 123 is not in STARTED state", + }, + { + name: "payment ID is empty", + payment: &data.Payment{ + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + Status: data.RegisteredReceiversWalletStatus, + }, + Disbursement: &data.Disbursement{ + Status: data.StartedDisbursementStatus, + }, + }, + expectedError: "payment ID is empty for Payment", + }, + { + name: "payment asset code is empty", + payment: &data.Payment{ + ID: "123", + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + Status: data.RegisteredReceiversWalletStatus, + }, + Disbursement: &data.Disbursement{ + Status: data.StartedDisbursementStatus, + }, + }, + expectedError: "payment asset code is empty for payment 123", + }, + { + name: "payment asset issuer is empty", + payment: &data.Payment{ + ID: "123", + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + Status: data.RegisteredReceiversWalletStatus, + }, + Disbursement: &data.Disbursement{ + Status: data.StartedDisbursementStatus, + }, + Asset: data.Asset{ + Code: "USDC", + }, + }, + expectedError: "payment asset issuer is empty for payment 123", + }, + { + name: "payment amount is invalid", + payment: &data.Payment{ + ID: "123", + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + Status: data.RegisteredReceiversWalletStatus, + }, + Disbursement: &data.Disbursement{ + Status: data.StartedDisbursementStatus, + }, + Asset: data.Asset{ + Code: "USDC", + Issuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVN", + }, + }, + expectedError: "payment amount is invalid for payment 123", + }, + { + name: "payment receiver wallet stellar address is empty", + payment: &data.Payment{ + ID: "123", + Status: data.ReadyPaymentStatus, + ReceiverWallet: &data.ReceiverWallet{ + Status: data.RegisteredReceiversWalletStatus, + }, + Disbursement: &data.Disbursement{ + Status: data.StartedDisbursementStatus, + }, + Asset: data.Asset{ + Code: "USDC", + Issuer: "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVN", + }, + Amount: "100.0", + }, + expectedError: "payment receiver wallet stellar address is empty for payment 123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validatePaymentReadyForSending(tc.payment) + if tc.expectedError == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.expectedError) + } + }) + } +} + +func Test_SendPaymentsService_RetryPayment(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + tssModel := txSubStore.NewTransactionModel(models.DBConnectionPool) + + service := NewSendPaymentsService(models) + + // clean test db + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + + // create fixtures + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://") + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GDUCE34WW5Z34GMCEPURYANUCUP47J6NORJLKC6GJNMDLN4ZI4PMI2MG") + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "started disbursement", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "100", + Status: data.ReadyPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + err = service.SendBatchPayments(ctx, 1) + require.NoError(t, err) + + paymentDB, err := models.Payment.Get(ctx, payment.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.PendingPaymentStatus, paymentDB.Status) + + transactions, err := tssModel.GetAllByPaymentIDs(ctx, []string{payment.ID}) + require.NoError(t, err) + assert.Len(t, transactions, 1) + + transaction := transactions[0] + assert.Equal(t, payment.ID, transaction.ExternalID) + assert.Equal(t, txSubStore.TransactionStatusPending, transaction.Status) + + // Marking the transaction as failed + transaction.Status = txSubStore.TransactionStatusProcessing + _, err = tssModel.UpdateStatusToError(ctx, *transaction, "Failing Test") + require.NoError(t, err) + + transactions, err = tssModel.GetAllByPaymentIDs(ctx, []string{payment.ID}) + require.NoError(t, err) + assert.Len(t, transactions, 1) + + transaction = transactions[0] + assert.Equal(t, payment.ID, transaction.ExternalID) + assert.Equal(t, txSubStore.TransactionStatusError, transaction.Status) + + err = models.Payment.Update(ctx, dbConnectionPool, paymentDB, &data.PaymentUpdate{ + Status: data.FailedPaymentStatus, + StellarTransactionID: "stellar-transaction-id-2", + }) + require.NoError(t, err) + paymentDB, err = models.Payment.Get(ctx, paymentDB.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.FailedPaymentStatus, paymentDB.Status) + + err = models.Payment.RetryFailedPayments(ctx, "email@test.com", paymentDB.ID) + require.NoError(t, err) + paymentDB, err = models.Payment.Get(ctx, paymentDB.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.ReadyPaymentStatus, paymentDB.Status) + + // insert a new transaction for the same payment + err = service.SendBatchPayments(ctx, 1) + require.NoError(t, err) + + paymentDB, err = models.Payment.Get(ctx, payment.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.PendingPaymentStatus, paymentDB.Status) + + transactions, err = tssModel.GetAllByPaymentIDs(ctx, []string{payment.ID}) + require.NoError(t, err) + assert.Len(t, transactions, 2) + + transaction1 := transactions[0] + transaction2 := transactions[1] + assert.Equal(t, txSubStore.TransactionStatusError, transaction1.Status) + assert.Equal(t, txSubStore.TransactionStatusPending, transaction2.Status) +} diff --git a/internal/services/send_receiver_wallets_invite_service.go b/internal/services/send_receiver_wallets_invite_service.go new file mode 100644 index 000000000..3fb0f6ace --- /dev/null +++ b/internal/services/send_receiver_wallets_invite_service.go @@ -0,0 +1,329 @@ +package services + +import ( + "context" + "fmt" + "html/template" + "net/url" + "path" + "strings" + + "github.com/stellar/go/strkey" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type SendReceiverWalletInviteService struct { + messengerClient message.MessengerClient + models *data.Models + anchorPlatformBaseSepURL string + minDaysBetweenRetries int + maxRetries int + sep10SigningPrivateKey string + crashTrackerClient crashtracker.CrashTrackerClient +} + +func (s SendReceiverWalletInviteService) validate() error { + if s.messengerClient == nil { + return fmt.Errorf("messenger client can't be nil") + } + + if s.anchorPlatformBaseSepURL == "" { + return fmt.Errorf("anchorPlatformBaseSepURL can't be empty") + } + + return nil +} + +// SendInvite sends the invitation’s deep link to the wallet’s application. +// The approach to sending the invitation is to send the deep link for each asset the wallet will pay based on the payment. +// For instance, the Wallet Foo is in two Ready Payments, one with USDC and the other with EUROC. +// So the receiver who has a Stellar Address pending registration (status:READY) in this wallet will receive both invites for USDC and EUROC. +// This would not impact the user receiving both token amounts. It's only for the registration process. +func (s SendReceiverWalletInviteService) SendInvite(ctx context.Context) error { + // Get the organization entry to get the Org name and SMSRegistrationMessageTemplate + organization, err := s.models.Organizations.Get(ctx) + if err != nil { + return fmt.Errorf("error getting organization: %w", err) + } + + // Execute the template early so we avoid hitting the database to query the other info + msgTemplate, err := template.New("").Parse(organization.SMSRegistrationMessageTemplate) + if err != nil { + return fmt.Errorf("error parsing SMS registration message template: %w", err) + } + + wallets, err := s.models.Wallets.GetAll(ctx) + if err != nil { + return fmt.Errorf("error getting all wallets: %w", err) + } + + walletsMap := make(map[string]data.Wallet, len(wallets)) + for _, wallet := range wallets { + walletsMap[wallet.ID] = wallet + } + + receiverWallets, err := s.models.ReceiverWallet.GetAllPendingRegistration(ctx, s.minDaysBetweenRetries, s.maxRetries) + if err != nil { + return fmt.Errorf("error getting receiver wallets pending registration: %w", err) + } + + receiverWalletsAsset, err := s.models.Assets.GetAssetsPerReceiverWallet(ctx, receiverWallets...) + if err != nil { + return fmt.Errorf("error getting all assets: %w", err) + } + + msgsToInsert := []*data.MessageInsert{} + // TODO: improve this code adding go routines + for _, rwa := range receiverWalletsAsset { + wallet := walletsMap[rwa.WalletID] + + wdl := WalletDeepLink{ + DeepLink: wallet.DeepLinkSchema, + AnchorPlatformBaseSepURL: s.anchorPlatformBaseSepURL, + OrganizationName: organization.Name, + AssetCode: rwa.Asset.Code, + AssetIssuer: rwa.Asset.Issuer, + } + + registrationLink, err := wdl.GetSignedRegistrationLink(s.sep10SigningPrivateKey) + if err != nil { + log.Ctx(ctx).Errorf( + "error getting signed registration link to receiver wallet ID %s for wallet ID %s and asset ID %s: %s", + rwa.ReceiverWallet.ID, wallet.ID, rwa.Asset.ID, err.Error(), + ) + continue + } + + content := new(strings.Builder) + err = msgTemplate.Execute(content, struct { + OrganizationName string + RegistrationLink template.HTML + }{ + OrganizationName: organization.Name, + RegistrationLink: template.HTML(registrationLink), + }) + if err != nil { + return fmt.Errorf("error executing registration message template: %w", err) + } + + msg := message.Message{ + ToPhoneNumber: rwa.ReceiverWallet.Receiver.PhoneNumber, + Message: content.String(), + } + + receiverWalletID := rwa.ReceiverWallet.ID + messageType := s.messengerClient.MessengerType() + msgToInsert := &data.MessageInsert{ + Type: messageType, + AssetID: nil, + ReceiverID: rwa.ReceiverWallet.Receiver.ID, + WalletID: wallet.ID, + ReceiverWalletID: &receiverWalletID, + TextEncrypted: content.String(), + } + + // We assume that the message will be sent at first + msgToInsert.Status = data.SuccessMessageStatus + if err := s.messengerClient.SendMessage(msg); err != nil { + msg := fmt.Sprintf( + "error sending message to receiver ID %s for receiver wallet ID %s using messenger type %s", + rwa.ReceiverWallet.Receiver.ID, rwa.ReceiverWallet.ID, messageType, + ) + // call crash tracker client to log and report error + s.crashTrackerClient.LogAndReportErrors(ctx, err, msg) + msgToInsert.Status = data.FailureMessageStatus + } + + msgsToInsert = append(msgsToInsert, msgToInsert) + } + + if err := s.models.Message.BulkInsert(ctx, msgsToInsert); err != nil { + return fmt.Errorf("error inserting messages in the database: %w", err) + } + + return nil +} + +func NewSendReceiverWalletInviteService(models *data.Models, messengerClient message.MessengerClient, anchorPlatformBaseSepURL, sep10SigningPrivateKey string, minDaysBetweenRetries, maxRetries int, crashTrackerClient crashtracker.CrashTrackerClient) (*SendReceiverWalletInviteService, error) { + s := &SendReceiverWalletInviteService{ + messengerClient: messengerClient, + models: models, + anchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + minDaysBetweenRetries: minDaysBetweenRetries, + maxRetries: maxRetries, + sep10SigningPrivateKey: sep10SigningPrivateKey, + crashTrackerClient: crashTrackerClient, + } + + if err := s.validate(); err != nil { + return nil, fmt.Errorf("invalid service setup: %w", err) + } + + return s, nil +} + +type WalletDeepLink struct { + // DeepLink is the deep link used to open the wallet invitation link. + DeepLink string + // Route is an optional parameter that can be used to specify the route to open in the wallet, in case it's not already present in the DeepLink. + Route string // (optional) + // AnchorPlatformBaseSepURL is the base URL of the /.well-known/stellar.toml file. + AnchorPlatformBaseSepURL string + // OrganizationName is the name of the organization that is sending the invitation. + OrganizationName string + // AssetCode is the code of the Stellar asset that the receiver will be able to receive. + AssetCode string + // AssetIssuer is the issuer of the Stellar asset that the receiver will be able to receive. + AssetIssuer string +} + +// BaseURLWithRoute returns the base URL of the deep link with the route appended. +func (wdl WalletDeepLink) BaseURLWithRoute() (string, error) { + if wdl.DeepLink == "" { + return "", fmt.Errorf("DeepLink can't be empty") + } + + deepLink, err := url.Parse(wdl.DeepLink) + if err != nil { + return "", fmt.Errorf("error parsing DeepLink: %w", err) + } + + if deepLink.Scheme == "" { + deepLink.Scheme = "https" + } + + if deepLink.Host == "" && deepLink.Path == "" && wdl.Route == "" { + return "", fmt.Errorf("the deep link needs to have a valid host, or path, or route") + } + + if wdl.Route != "" { + if deepLink.Path == "" && deepLink.Host == "" { + deepLink.Path = wdl.Route + } else { + deepLink.Path = path.Join(deepLink.Path, wdl.Route) + } + } + + return deepLink.String(), nil +} + +func (wdl WalletDeepLink) TomlFileDomain() (string, error) { + if wdl.AnchorPlatformBaseSepURL == "" { + return "", fmt.Errorf("AnchorPlatformBaseSepURL can't be empty") + } + + anchorPlatformBaseSepURL := wdl.AnchorPlatformBaseSepURL + if !strings.Contains(anchorPlatformBaseSepURL, "://") { + anchorPlatformBaseSepURL = "http://" + anchorPlatformBaseSepURL + } + + anchorURL, err := url.Parse(anchorPlatformBaseSepURL) + if err != nil { + return "", fmt.Errorf("error parsing AnchorPlatformBaseSepURL '%s': %w", anchorPlatformBaseSepURL, err) + } + + return anchorURL.Hostname(), nil +} + +// validate will make sure all the parameters are set correctly. +func (wdl WalletDeepLink) validate() error { + if wdl.DeepLink == "" { + return fmt.Errorf("wallet schema can't be empty") + } + + _, err := wdl.BaseURLWithRoute() + if err != nil { + return fmt.Errorf("can't generate a valid base URL for the deep link: %w", err) + } + + if wdl.AnchorPlatformBaseSepURL == "" { + return fmt.Errorf("toml file domain can't be empty") + } + + if wdl.OrganizationName == "" { + return fmt.Errorf("organization name can't be empty") + } + + if wdl.AssetCode == "" { + return fmt.Errorf("asset code can't be empty") + } + + // not XLM: + if strings.ToUpper(wdl.AssetCode) != "XLM" { + if wdl.AssetIssuer == "" { + return fmt.Errorf("asset issuer can't be empty unless the asset code is XLM") + } + + if !strkey.IsValidEd25519PublicKey(wdl.AssetIssuer) { + return fmt.Errorf("asset issuer is not a valid Ed25519 public key %s", wdl.AssetIssuer) + } + + return nil + } + + // XLM: + if wdl.AssetIssuer != "" { + return fmt.Errorf("asset issuer should be empty for XLM, but is %s", wdl.AssetIssuer) + } + + return nil +} + +// GetUnsignedRegistrationLink creates a deep link for the wallet registration using the format below: +// ?&&. +func (wdl WalletDeepLink) GetUnsignedRegistrationLink() (string, error) { + if err := wdl.validate(); err != nil { + return "", fmt.Errorf("validating WalletDeepLink: %w", err) + } + + assetName := wdl.AssetCode + if wdl.AssetIssuer != "" { + assetName += "-" + wdl.AssetIssuer + } + + tomlFileDomain, err := wdl.TomlFileDomain() + if err != nil { + return "", fmt.Errorf("getting WalletDeepLink toml file domain: %w", err) + } + + baseURLWithRoute, err := wdl.BaseURLWithRoute() + if err != nil { + return "", fmt.Errorf("getting WalletDeepLink base URL: %w", err) + } + + u, err := url.Parse(baseURLWithRoute) + if err != nil { + return "", fmt.Errorf("parsing DeepLink: %w", err) + } + + q := u.Query() + q.Add("domain", tomlFileDomain) + q.Add("name", wdl.OrganizationName) + q.Add("asset", assetName) + + u.RawQuery = q.Encode() + + return u.String(), nil +} + +// GetSignedRegistrationLink will return the registration link accompanied with an extra query parameter containing the +// signature of the registration link, where the signature is created using the stellarSecretKey with the unsigned link +// as the message, keeping in mind that the insigned link query parameters were sorted in alphabetical order to generate +// the signature. +func (wdl WalletDeepLink) GetSignedRegistrationLink(stellarSecretKey string) (string, error) { + unsignedRegistrationLink, err := wdl.GetUnsignedRegistrationLink() + if err != nil { + return "", fmt.Errorf("error getting unsigned registration link: %w", err) + } + + signedRegistrationLink, err := utils.SignURL(stellarSecretKey, unsignedRegistrationLink) + if err != nil { + return "", fmt.Errorf("error signing registration link: %w", err) + } + + return signedRegistrationLink, nil +} diff --git a/internal/services/send_receiver_wallets_invite_service_test.go b/internal/services/send_receiver_wallets_invite_service_test.go new file mode 100644 index 000000000..4c7391661 --- /dev/null +++ b/internal/services/send_receiver_wallets_invite_service_test.go @@ -0,0 +1,709 @@ +package services + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/message" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetSignedRegistrationLink_SchemelessDeepLink(t *testing.T) { + wdl := WalletDeepLink{ + DeepLink: "api-dev.vibrantapp.com/sdp-dev", + AnchorPlatformBaseSepURL: "https://ap.localhost.com", + OrganizationName: "FOO Org", + AssetCode: "USDC", + AssetIssuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + } + + registrationLink, err := wdl.GetSignedRegistrationLink("SCTOVDWM3A7KLTXXIV6YXL6QRVUIIG4HHHIDDKPR4JUB3DGDIKI5VGA2") + require.NoError(t, err) + wantRegistrationLink := "https://api-dev.vibrantapp.com/sdp-dev?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap.localhost.com&name=FOO+Org&signature=b40479041eea534a029c6aadf36f3bf6696aba9ff64684b558b9a412150b31fa8480ac7babcdef17cb445c1d105a761dbaa3599361c2d9e1d526fd4a5bac370a" + require.Equal(t, wantRegistrationLink, registrationLink) + + wdl = WalletDeepLink{ + DeepLink: "https://www.beansapp.com/disbursements/registration?redirect=true", + AnchorPlatformBaseSepURL: "https://ap.localhost.com", + OrganizationName: "FOO Org", + AssetCode: "USDC", + AssetIssuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + } + + registrationLink, err = wdl.GetSignedRegistrationLink("SCTOVDWM3A7KLTXXIV6YXL6QRVUIIG4HHHIDDKPR4JUB3DGDIKI5VGA2") + require.NoError(t, err) + wantRegistrationLink = "https://www.beansapp.com/disbursements/registration?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap.localhost.com&name=FOO+Org&redirect=true&signature=8dd0a570bf5590a8e1a4983d413b5429ed504659543cf180fbf1b3ffbf0ea90083789a7c0c615d9cbddbe0c59f7555e6fd33fb5ca8f4685c821fc23ad7cd2f0d" + require.Equal(t, wantRegistrationLink, registrationLink) +} + +func Test_SendReceiverWalletInviteService(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + anchorPlatformBaseSepURL := "http://localhost:8000" + stellarSecretKey := "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5" + messengerClientMock := &message.MessengerClientMock{} + messengerClientMock. + On("MessengerType"). + Return(message.MessengerTypeTwilioSMS) + + mockCrashTrackerClient := &crashtracker.MockCrashTrackerClient{} + + ctx := context.Background() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "ATL", "Atlantis") + + wallet1 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet1", "https://wallet1.com", "www.wallet1.com", "wallet1://sdp") + wallet2 := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet2", "https://wallet2.com", "www.wallet2.com", "wallet2://sdp") + + asset1 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "FOO1", "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX") + asset2 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "FOO2", "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX") + + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + + disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Country: country, + Wallet: wallet1, + Status: data.ReadyDisbursementStatus, + Asset: asset1, + }) + + disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Country: country, + Wallet: wallet2, + Status: data.ReadyDisbursementStatus, + Asset: asset2, + }) + + t.Run("returns error when service has wrong setup", func(t *testing.T) { + _, err := NewSendReceiverWalletInviteService(models, nil, anchorPlatformBaseSepURL, stellarSecretKey, 3, 2, mockCrashTrackerClient) + assert.EqualError(t, err, "invalid service setup: messenger client can't be nil") + + _, err = NewSendReceiverWalletInviteService(models, messengerClientMock, "", stellarSecretKey, 3, 2, mockCrashTrackerClient) + assert.EqualError(t, err, "invalid service setup: anchorPlatformBaseSepURL can't be empty") + }) + + t.Run("inserts the failed sent message", func(t *testing.T) { + s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, anchorPlatformBaseSepURL, stellarSecretKey, 3, 2, mockCrashTrackerClient) + require.NoError(t, err) + + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllMessagesFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + + rec1RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, data.ReadyReceiversWalletStatus) + data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet2.ID, data.RegisteredReceiversWalletStatus) + + rec2RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet2.ID, data.ReadyReceiversWalletStatus) + + _ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Status: data.ReadyPaymentStatus, + Disbursement: disbursement1, + Asset: *asset1, + ReceiverWallet: rec1RW, + Amount: "1", + }) + + _ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Status: data.ReadyPaymentStatus, + Disbursement: disbursement2, + Asset: *asset2, + ReceiverWallet: rec2RW, + Amount: "1", + }) + + walletDeepLink1 := WalletDeepLink{ + DeepLink: wallet1.DeepLinkSchema, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + OrganizationName: "MyCustomAid", + AssetCode: asset1.Code, + AssetIssuer: asset1.Issuer, + } + deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + contentWallet1 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink1) + + walletDeepLink2 := WalletDeepLink{ + DeepLink: wallet2.DeepLinkSchema, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + OrganizationName: "MyCustomAid", + AssetCode: asset2.Code, + AssetIssuer: asset2.Issuer, + } + deepLink2, err := walletDeepLink2.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + contentWallet2 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink2) + + mockErr := errors.New("unexpected error") + messengerClientMock. + On("SendMessage", message.Message{ + ToPhoneNumber: receiver1.PhoneNumber, + Message: contentWallet1, + }). + Return(errors.New("unexpected error")). + Once(). + On("SendMessage", message.Message{ + ToPhoneNumber: receiver2.PhoneNumber, + Message: contentWallet2, + }). + Return(nil). + Once() + + mockMsg := fmt.Sprintf( + "error sending message to receiver ID %s for receiver wallet ID %s using messenger type %s", + receiver1.ID, rec1RW.ID, message.MessengerTypeTwilioSMS, + ) + mockCrashTrackerClient.On("LogAndReportErrors", ctx, mockErr, mockMsg).Once() + + err = s.SendInvite(ctx) + require.NoError(t, err) + + q := ` + SELECT + type, status, receiver_id, wallet_id, receiver_wallet_id, + title_encrypted, text_encrypted, status_history + FROM + messages + WHERE + receiver_id = $1 AND wallet_id = $2 AND receiver_wallet_id = $3 + ` + var msg data.Message + err = dbConnectionPool.GetContext(ctx, &msg, q, receiver1.ID, wallet1.ID, rec1RW.ID) + require.NoError(t, err) + + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, receiver1.ID, msg.ReceiverID) + assert.Equal(t, wallet1.ID, msg.WalletID) + assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID) + assert.Equal(t, data.FailureMessageStatus, msg.Status) + assert.Empty(t, msg.TitleEncrypted) + assert.Equal(t, contentWallet1, msg.TextEncrypted) + assert.Len(t, msg.StatusHistory, 2) + assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status) + assert.Equal(t, data.FailureMessageStatus, msg.StatusHistory[1].Status) + assert.Nil(t, msg.AssetID) + + msg = data.Message{} + err = dbConnectionPool.GetContext(ctx, &msg, q, receiver2.ID, wallet2.ID, rec2RW.ID) + require.NoError(t, err) + + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, receiver2.ID, msg.ReceiverID) + assert.Equal(t, wallet2.ID, msg.WalletID) + assert.Equal(t, rec2RW.ID, *msg.ReceiverWalletID) + assert.Equal(t, data.SuccessMessageStatus, msg.Status) + assert.Empty(t, msg.TitleEncrypted) + assert.Equal(t, contentWallet2, msg.TextEncrypted) + assert.Len(t, msg.StatusHistory, 2) + assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status) + assert.Equal(t, data.SuccessMessageStatus, msg.StatusHistory[1].Status) + assert.Nil(t, msg.AssetID) + + mockCrashTrackerClient.AssertExpectations(t) + }) + + t.Run("send invite successfully", func(t *testing.T) { + s, err := NewSendReceiverWalletInviteService(models, messengerClientMock, anchorPlatformBaseSepURL, stellarSecretKey, 3, 2, mockCrashTrackerClient) + require.NoError(t, err) + + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllMessagesFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + + rec1RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet1.ID, data.ReadyReceiversWalletStatus) + data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet2.ID, data.RegisteredReceiversWalletStatus) + + rec2RW := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet2.ID, data.ReadyReceiversWalletStatus) + + _ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Status: data.ReadyPaymentStatus, + Disbursement: disbursement1, + Asset: *asset1, + ReceiverWallet: rec1RW, + Amount: "1", + }) + + _ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Status: data.ReadyPaymentStatus, + Disbursement: disbursement2, + Asset: *asset2, + ReceiverWallet: rec1RW, + Amount: "1", + }) + + _ = data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Status: data.ReadyPaymentStatus, + Disbursement: disbursement2, + Asset: *asset2, + ReceiverWallet: rec2RW, + Amount: "1", + }) + + walletDeepLink1 := WalletDeepLink{ + DeepLink: wallet1.DeepLinkSchema, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + OrganizationName: "MyCustomAid", + AssetCode: asset1.Code, + AssetIssuer: asset1.Issuer, + } + deepLink1, err := walletDeepLink1.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + contentWallet1 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink1) + + walletDeepLink2 := WalletDeepLink{ + DeepLink: wallet2.DeepLinkSchema, + AnchorPlatformBaseSepURL: anchorPlatformBaseSepURL, + OrganizationName: "MyCustomAid", + AssetCode: asset2.Code, + AssetIssuer: asset2.Issuer, + } + deepLink2, err := walletDeepLink2.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + contentWallet2 := fmt.Sprintf("You have a payment waiting for you from the MyCustomAid. Click %s to register.", deepLink2) + + messengerClientMock. + On("SendMessage", message.Message{ + ToPhoneNumber: receiver1.PhoneNumber, + Message: contentWallet1, + }). + Return(nil). + Once(). + On("SendMessage", message.Message{ + ToPhoneNumber: receiver1.PhoneNumber, + Message: contentWallet2, + }). + Return(nil). + Once(). + On("SendMessage", message.Message{ + ToPhoneNumber: receiver2.PhoneNumber, + Message: contentWallet2, + }). + Return(nil). + Once() + + err = s.SendInvite(ctx) + require.NoError(t, err) + + q := ` + SELECT + type, status, receiver_id, wallet_id, receiver_wallet_id, + title_encrypted, text_encrypted, status_history + FROM + messages + WHERE + receiver_id = $1 AND wallet_id = $2 AND receiver_wallet_id = $3 + ` + var msg data.Message + err = dbConnectionPool.GetContext(ctx, &msg, q, receiver1.ID, wallet1.ID, rec1RW.ID) + require.NoError(t, err) + + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, receiver1.ID, msg.ReceiverID) + assert.Equal(t, wallet1.ID, msg.WalletID) + assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID) + assert.Equal(t, data.SuccessMessageStatus, msg.Status) + assert.Empty(t, msg.TitleEncrypted) + assert.Equal(t, contentWallet1, msg.TextEncrypted) + assert.Len(t, msg.StatusHistory, 2) + assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status) + assert.Equal(t, data.SuccessMessageStatus, msg.StatusHistory[1].Status) + assert.Nil(t, msg.AssetID) + + msg = data.Message{} + err = dbConnectionPool.GetContext(ctx, &msg, q, receiver1.ID, wallet2.ID, rec1RW.ID) + require.NoError(t, err) + + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, receiver1.ID, msg.ReceiverID) + assert.Equal(t, wallet2.ID, msg.WalletID) + assert.Equal(t, rec1RW.ID, *msg.ReceiverWalletID) + assert.Equal(t, data.SuccessMessageStatus, msg.Status) + assert.Empty(t, msg.TitleEncrypted) + assert.Equal(t, contentWallet2, msg.TextEncrypted) + assert.Len(t, msg.StatusHistory, 2) + assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status) + assert.Equal(t, data.SuccessMessageStatus, msg.StatusHistory[1].Status) + assert.Nil(t, msg.AssetID) + + msg = data.Message{} + err = dbConnectionPool.GetContext(ctx, &msg, q, receiver2.ID, wallet2.ID, rec2RW.ID) + require.NoError(t, err) + + assert.Equal(t, message.MessengerTypeTwilioSMS, msg.Type) + assert.Equal(t, receiver2.ID, msg.ReceiverID) + assert.Equal(t, wallet2.ID, msg.WalletID) + assert.Equal(t, rec2RW.ID, *msg.ReceiverWalletID) + assert.Equal(t, data.SuccessMessageStatus, msg.Status) + assert.Empty(t, msg.TitleEncrypted) + assert.Equal(t, contentWallet2, msg.TextEncrypted) + assert.Len(t, msg.StatusHistory, 2) + assert.Equal(t, data.PendingMessageStatus, msg.StatusHistory[0].Status) + assert.Equal(t, data.SuccessMessageStatus, msg.StatusHistory[1].Status) + assert.Nil(t, msg.AssetID) + }) +} + +func Test_WalletDeepLink_BaseURL(t *testing.T) { + testCases := []struct { + name string + deepLink string + route string + wantResult string + wantErr error + }{ + { + name: "empty deep link and route", + wantErr: fmt.Errorf("DeepLink can't be empty"), + }, + { + name: "deep link with path [without schema] (empty route param)", + deepLink: "api-dev.vibrantapp.com", + wantResult: "https://api-dev.vibrantapp.com", + wantErr: nil, + }, + { + name: "deep link with path [without schema] (with route param)", + deepLink: "api-dev.vibrantapp.com", + route: "foo", + wantResult: "https://api-dev.vibrantapp.com/foo", + wantErr: nil, + }, + { + name: "deep link with path [without schema] {embedded route} (empty route param)", + deepLink: "api-dev.vibrantapp.com/foo", + wantResult: "https://api-dev.vibrantapp.com/foo", + wantErr: nil, + }, + { + name: "deep link with path [without schema] {embedded route} (with route param)", + deepLink: "api-dev.vibrantapp.com/foo", + route: "bar", + wantResult: "https://api-dev.vibrantapp.com/foo/bar", + wantErr: nil, + }, + { + name: "deep link with path [with schema] (empty route param)", + deepLink: "https://api-dev.vibrantapp.com", + wantResult: "https://api-dev.vibrantapp.com", + wantErr: nil, + }, + { + name: "deep link with path [with schema] (with route param)", + deepLink: "https://api-dev.vibrantapp.com", + route: "foo", + wantResult: "https://api-dev.vibrantapp.com/foo", + wantErr: nil, + }, + { + name: "deep link with path [with schema] {embedded route} (empty route param)", + deepLink: "https://api-dev.vibrantapp.com/foo", + wantResult: "https://api-dev.vibrantapp.com/foo", + wantErr: nil, + }, + { + name: "deep link with path [with schema] {embedded route} (with route param)", + deepLink: "https://api-dev.vibrantapp.com/foo", + route: "bar", + wantResult: "https://api-dev.vibrantapp.com/foo/bar", + wantErr: nil, + }, + { + name: "deep link with path [with schema] {embedded route}", + deepLink: "https://api-dev.vibrantapp.com/foo", + route: "bar", + wantResult: "https://api-dev.vibrantapp.com/foo/bar", + wantErr: nil, + }, + { + name: "deep link without path [ONLY schema]", + deepLink: "vibrant+aid://", + wantErr: fmt.Errorf("the deep link needs to have a valid host, or path, or route"), + }, + { + name: "deep link with path [with schema] {embedded route} (with route param)", + deepLink: "vibrant+aid://foo", + wantResult: "vibrant+aid://foo", + wantErr: nil, + }, + { + name: "deep link with path [with schema] {embedded route} (with route param)", + deepLink: "vibrant+aid://foo", + route: "bar", + wantResult: "vibrant+aid://foo/bar", + wantErr: nil, + }, + { + name: "deep link [with query params]", + deepLink: "vibrant+aid://foo?redirect=true", + wantResult: "vibrant+aid://foo?redirect=true", + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wdl := WalletDeepLink{ + DeepLink: tc.deepLink, + Route: tc.route, + } + + gotBaseURLWithRoute, err := wdl.BaseURLWithRoute() + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantResult, gotBaseURLWithRoute) + }) + } +} + +func Test_WalletDeepLink_TomlFileDomain(t *testing.T) { + testCases := []struct { + link string + wantResult string + wantErr error + }{ + { + link: "", + wantResult: "", + wantErr: fmt.Errorf("AnchorPlatformBaseSepURL can't be empty"), + }, + { + link: "test.com", + wantResult: "test.com", + wantErr: nil, + }, + { + link: "https://test.com", + wantResult: "test.com", + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.link, func(t *testing.T) { + wdl := WalletDeepLink{ + AnchorPlatformBaseSepURL: tc.link, + } + + result, err := wdl.TomlFileDomain() + require.Equal(t, tc.wantErr, err) + require.Equal(t, tc.wantResult, result) + }) + } +} + +func Test_WalletDeepLink_validate(t *testing.T) { + // wallet schema can't be empty + wdl := WalletDeepLink{} + err := wdl.validate() + require.EqualError(t, err, "wallet schema can't be empty") + + // we need a host, a path or a route + wdl.DeepLink = "wallet://" + err = wdl.validate() + require.EqualError(t, err, "can't generate a valid base URL for the deep link: the deep link needs to have a valid host, or path, or route") + + // toml file domain can't be empty + wdl.DeepLink = "wallet://sdp" + err = wdl.validate() + require.EqualError(t, err, "toml file domain can't be empty") + + // toml file domain can't be empty (different setup) + wdl.DeepLink = "wallet://" + wdl.Route = "sdp" + err = wdl.validate() + require.EqualError(t, err, "toml file domain can't be empty") + + // organization name can't be empty + wdl.AnchorPlatformBaseSepURL = "foo.bar" + err = wdl.validate() + require.EqualError(t, err, "organization name can't be empty") + + // asset code can't be empty + wdl.OrganizationName = "Foo Bar Org" + err = wdl.validate() + require.EqualError(t, err, "asset code can't be empty") + + // asset issuer can't be empty if it's not native (XLM) + wdl.AssetCode = "FOO" + err = wdl.validate() + require.EqualError(t, err, "asset issuer can't be empty unless the asset code is XLM") + + // asset issuer needs to be a valid Ed25519PublicKey + wdl.AssetIssuer = "BAR" + err = wdl.validate() + require.EqualError(t, err, "asset issuer is not a valid Ed25519 public key BAR") + + // Successful for non-native assets πŸŽ‰ + wdl.AssetIssuer = "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX" + err = wdl.validate() + require.NoError(t, err) + + // asset issuer needs to be empty if it's native (XLM) + wdl.AssetCode = "XLM" + wdl.AssetIssuer = "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX" + err = wdl.validate() + require.EqualError(t, err, "asset issuer should be empty for XLM, but is GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX") + + // Successful for native (XLM) assets πŸŽ‰ + wdl.AssetIssuer = "" + err = wdl.validate() + require.NoError(t, err) +} + +func Test_WalletDeepLink_GetUnsignedRegistrationLink(t *testing.T) { + testCases := []struct { + name string + walletDeepLink WalletDeepLink + wantResult string + wantErrContains string + }{ + { + name: "returns error if WalletDeepLink validation fails", + wantErrContains: "validating WalletDeepLink: wallet schema can't be empty", + }, + { + name: "πŸŽ‰ successful for non-native assets", + walletDeepLink: WalletDeepLink{ + DeepLink: "wallet://", + Route: "sdp", // route added separated from the deep link + AnchorPlatformBaseSepURL: "foo.bar", + OrganizationName: "Foo Bar Org", + AssetCode: "FOO", + AssetIssuer: "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX", + }, + wantResult: "wallet://sdp?asset=FOO-GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX&domain=foo.bar&name=Foo+Bar+Org", + }, + { + name: "πŸŽ‰ successful for native (XLM) assets", + walletDeepLink: WalletDeepLink{ + DeepLink: "wallet://sdp", // route added directly to the deep link + AnchorPlatformBaseSepURL: "foo.bar", + OrganizationName: "Foo Bar Org", + AssetCode: "XLM", + }, + wantResult: "wallet://sdp?asset=XLM&domain=foo.bar&name=Foo+Bar+Org", + }, + { + name: "πŸŽ‰ successful for deeplink with query params", + walletDeepLink: WalletDeepLink{ + DeepLink: "wallet://sdp?custom=true", + AnchorPlatformBaseSepURL: "foo.bar", + OrganizationName: "Foo Bar Org", + AssetCode: "FOO", + AssetIssuer: "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX", + }, + wantResult: "wallet://sdp?asset=FOO-GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX&custom=true&domain=foo.bar&name=Foo+Bar+Org", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotResult, err := tc.walletDeepLink.GetUnsignedRegistrationLink() + if tc.wantErrContains != "" { + assert.Empty(t, gotResult) + assert.Contains(t, err.Error(), tc.wantErrContains) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.wantResult, gotResult) + } + }) + } +} + +func Test_WalletDeepLink_GetSignedRegistrationLink(t *testing.T) { + stellarPublicKey := "GBFDUUZ5ZYC6RAPOQLM7IYXLFHYTMCYXBGM7NIC4EE2MWOSGIYCOSN5F" + stellarSecretKey := "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5" + + t.Run("fails if something is wrong with the WalletDeepLink object", func(t *testing.T) { + wdl := WalletDeepLink{} + actual, err := wdl.GetSignedRegistrationLink(stellarSecretKey) + require.Empty(t, actual) + require.EqualError(t, err, "error getting unsigned registration link: validating WalletDeepLink: wallet schema can't be empty") + }) + + t.Run("fails if the private key is invalid", func(t *testing.T) { + wdl := WalletDeepLink{ + DeepLink: "wallet://", + Route: "sdp", + AnchorPlatformBaseSepURL: "foo.bar", + OrganizationName: "Foo Bar Org", + AssetCode: "FOO", + AssetIssuer: "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX", + } + + actual, err := wdl.GetSignedRegistrationLink("invalid-secret-key") + require.Empty(t, actual) + require.EqualError(t, err, "error signing registration link: error parsing stellar private key: base32 decode failed: illegal base32 data at input byte 18") + }) + + t.Run("Successful for non-native assets πŸŽ‰", func(t *testing.T) { + wdl := WalletDeepLink{ + DeepLink: "wallet://sdp", + AnchorPlatformBaseSepURL: "foo.bar", + OrganizationName: "Foo Bar Org", + AssetCode: "FOO", + AssetIssuer: "GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX", + } + + expected := "wallet://sdp?asset=FOO-GCKGCKZ2PFSCRQXREJMTHAHDMOZQLS2R4V5LZ6VLU53HONH5FI6ACBSX&domain=foo.bar&name=Foo+Bar+Org&signature=361b0c0e6094dc35e0baa8ccae99bac1bdddc099e8bf6f68f4045e15b99c96d1a39c5343bb010a0b34f29a3490d233d43e3e2f5e537cf52d85f62deb75b2150d" + actual, err := wdl.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + require.Equal(t, expected, actual) + + isValid, err := utils.VerifySignedURL(actual, stellarPublicKey) + require.NoError(t, err) + require.True(t, isValid) + }) + + t.Run("Successful for native (XLM) assets πŸŽ‰", func(t *testing.T) { + wdl := WalletDeepLink{ + DeepLink: "wallet://", + Route: "sdp", + AnchorPlatformBaseSepURL: "foo.bar", + OrganizationName: "Foo Bar Org", + AssetCode: "XLM", + } + + expected := "wallet://sdp?asset=XLM&domain=foo.bar&name=Foo+Bar+Org&signature=d3ffb7c9f78d2131b5be4e3a1302cfe87685706e36f6f1115e4b28bb940cc75532d56ab1d5c5f3481f210021811510290735858ea35b88e26cd5a115f7ea450b" + actual, err := wdl.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + require.Equal(t, expected, actual) + + isValid, err := utils.VerifySignedURL(actual, stellarPublicKey) + require.NoError(t, err) + require.True(t, isValid) + }) + + t.Run("Successful for native (XLM) assets and AnchorPlatformBaseSepURL with https:// schema πŸŽ‰", func(t *testing.T) { + wdl := WalletDeepLink{ + DeepLink: "wallet://sdp", + AnchorPlatformBaseSepURL: "https://foo.bar", + OrganizationName: "Foo Bar Org", + AssetCode: "XLM", + } + + expected := "wallet://sdp?asset=XLM&domain=foo.bar&name=Foo+Bar+Org&signature=d3ffb7c9f78d2131b5be4e3a1302cfe87685706e36f6f1115e4b28bb940cc75532d56ab1d5c5f3481f210021811510290735858ea35b88e26cd5a115f7ea450b" + actual, err := wdl.GetSignedRegistrationLink(stellarSecretKey) + require.NoError(t, err) + require.Equal(t, expected, actual) + + isValid, err := utils.VerifySignedURL(actual, stellarPublicKey) + require.NoError(t, err) + require.True(t, isValid) + }) +} diff --git a/internal/services/setup_assets_for_network_service.go b/internal/services/setup_assets_for_network_service.go new file mode 100644 index 000000000..c5270bf98 --- /dev/null +++ b/internal/services/setup_assets_for_network_service.go @@ -0,0 +1,119 @@ +package services + +import ( + "context" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type AssetsNetworkMapType map[utils.NetworkType]map[string]string + +var DefaultAssetsNetworkMap = AssetsNetworkMapType{ + utils.PubnetNetworkType: { + "USDC": "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVN", + }, + utils.TestnetNetworkType: { + "USDC": "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + }, +} + +// SetupAssetsForProperNetwork updates and inserts assets for the given Network Passphrase (`network`). So it avoids the application having +// same asset code with multiple issuers. +func SetupAssetsForProperNetwork(ctx context.Context, dbConnectionPool db.DBConnectionPool, network utils.NetworkType, assetsNetworkMap AssetsNetworkMapType) error { + log.Ctx(ctx).Infof("updating/inserting assets for the '%s' network\n\n", network) + + assets, ok := assetsNetworkMap[network] + if !ok { + return fmt.Errorf("invalid network provided") + } + + var codes, issuers []string + + separator := strings.Repeat("-", 20) + buf := new(strings.Builder) + buf.WriteString("assets' code that will be updated or inserted:\n\n") + for code, issuer := range assets { + codes = append(codes, code) + issuers = append(issuers, issuer) + + buf.WriteString(fmt.Sprintf("Code: %s\n%s\n\n", code, separator)) + } + + log.Ctx(ctx).Info(buf.String()) + + err := db.RunInTransaction(ctx, dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + query := ` + WITH assets_to_update_or_insert AS ( + -- gather all assets passed as parameters for the query and turn into SQL rows + SELECT UNNEST($1::text[]) AS code, UNNEST($2::text[]) AS issuer + ), + existing_assets AS ( + -- gets all assets that the code appears in the codes passed as parameter for the query + SELECT + * + FROM + assets + WHERE + code = ANY($1::text[]) + FOR UPDATE + ), + update_existing_assets AS ( + -- updates the existing assets resulted in 'existing_assets' CTE + UPDATE + assets a + SET + issuer = atui.issuer + FROM + existing_assets ea + INNER JOIN assets_to_update_or_insert atui ON ea.code = atui.code + WHERE + a.id = ea.id AND a.issuer != atui.issuer + ) + -- inserts assets in the database + INSERT INTO assets + (code, issuer) + SELECT + atui.code, atui.issuer + FROM + assets_to_update_or_insert atui + WHERE + atui.code NOT IN (SELECT code FROM existing_assets) + ` + + _, err := dbTx.ExecContext(ctx, query, pq.Array(codes), pq.Array(issuers)) + if err != nil { + return fmt.Errorf("error upserting assets: %w", err) + } + + return nil + }) + if err != nil { + return fmt.Errorf("error upserting assets for the proper network: %w", err) + } + + models, err := data.NewModels(dbConnectionPool) + if err != nil { + return fmt.Errorf("error getting models: %w", err) + } + + allAssets, err := models.Assets.GetAll(ctx) + if err != nil { + return fmt.Errorf("error getting all available assets on database: %w", err) + } + + buf.Reset() + buf.WriteString(fmt.Sprintf("Registered assets for network %s:\n\n", network)) + for _, asset := range allAssets { + buf.WriteString(fmt.Sprintf("Code: %s\nIssuer: %s\n%s\n\n", asset.Code, asset.Issuer, separator)) + } + + log.Ctx(ctx).Info(buf.String()) + + return nil +} diff --git a/internal/services/setup_assets_for_network_service_test.go b/internal/services/setup_assets_for_network_service_test.go new file mode 100644 index 000000000..e07a85707 --- /dev/null +++ b/internal/services/setup_assets_for_network_service_test.go @@ -0,0 +1,183 @@ +package services + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_SetupAssetsForProperNetwork(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + + t.Run("returns error when a invalid network is set", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + err := SetupAssetsForProperNetwork(ctx, dbConnectionPool, "invalid", DefaultAssetsNetworkMap) + assert.EqualError(t, err, "invalid network provided") + }) + + t.Run("inserts new assets when it doesn't exist", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + buf := new(strings.Builder) + log.DefaultLogger.SetLevel(log.InfoLevel) + log.DefaultLogger.SetOutput(buf) + + err := SetupAssetsForProperNetwork(ctx, dbConnectionPool, utils.TestnetNetworkType, DefaultAssetsNetworkMap) + require.NoError(t, err) + + assets, err := models.Assets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, assets, 1) + assert.Equal(t, "USDC", assets[0].Code) + assert.Equal(t, "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", assets[0].Issuer) + + expectedLogs := []string{ + "updating/inserting assets for the 'testnet' network", + "Code: USDC", + "Issuer: GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + } + + logs := buf.String() + for _, expectedLog := range expectedLogs { + assert.Contains(t, logs, expectedLog) + } + }) + + t.Run("updates and inserts assets", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + pubnetEUROCIssuer := keypair.MustRandom().Address() + data.CreateAssetFixture(t, ctx, dbConnectionPool, "EUROC", pubnetEUROCIssuer) + + testnetUSDCIssuer := keypair.MustRandom().Address() + testnetEUROCIssuer := keypair.MustRandom().Address() + + assert.NotEqual(t, testnetEUROCIssuer, pubnetEUROCIssuer) + + assets, err := models.Assets.GetAll(ctx) + require.NoError(t, err) + assert.Len(t, assets, 1) + assert.Equal(t, "EUROC", assets[0].Code) + assert.Equal(t, pubnetEUROCIssuer, assets[0].Issuer) + + assetsNetworkMap := AssetsNetworkMapType{ + utils.TestnetNetworkType: { + "EUROC": testnetEUROCIssuer, + "USDC": testnetUSDCIssuer, + }, + } + + buf := new(strings.Builder) + log.DefaultLogger.SetLevel(log.InfoLevel) + log.DefaultLogger.SetOutput(buf) + + err = SetupAssetsForProperNetwork(ctx, dbConnectionPool, utils.TestnetNetworkType, assetsNetworkMap) + require.NoError(t, err) + + assets, err = models.Assets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, assets, 2) + assert.Equal(t, "EUROC", assets[0].Code) + assert.Equal(t, testnetEUROCIssuer, assets[0].Issuer) + assert.Equal(t, "USDC", assets[1].Code) + assert.Equal(t, testnetUSDCIssuer, assets[1].Issuer) + + expectedLogs := []string{ + "updating/inserting assets for the 'testnet' network", + "Code: EUROC", + "Code: USDC", + fmt.Sprintf("Issuer: %s", testnetEUROCIssuer), + fmt.Sprintf("Issuer: %s", testnetEUROCIssuer), + } + + logs := buf.String() + for _, expectedLog := range expectedLogs { + assert.Contains(t, logs, expectedLog) + } + }) + + t.Run("doesn't change the asset when it's not in the assetsNetworkMap", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + testnetEUROCIssuer := keypair.MustRandom().Address() + data.CreateAssetFixture(t, ctx, dbConnectionPool, "EUROC", testnetEUROCIssuer) + + pubnetARSTIssuer := keypair.MustRandom().Address() + data.CreateAssetFixture(t, ctx, dbConnectionPool, "ARST", pubnetARSTIssuer) + + pubnetUSDCIssuer := keypair.MustRandom().Address() + pubnetEUROCIssuer := keypair.MustRandom().Address() + + assert.NotEqual(t, testnetEUROCIssuer, pubnetEUROCIssuer) + + assets, err := models.Assets.GetAll(ctx) + require.NoError(t, err) + assert.Len(t, assets, 2) + assert.Equal(t, "ARST", assets[0].Code) + assert.Equal(t, pubnetARSTIssuer, assets[0].Issuer) + assert.Equal(t, "EUROC", assets[1].Code) + assert.Equal(t, testnetEUROCIssuer, assets[1].Issuer) + + assetsNetworkMap := AssetsNetworkMapType{ + utils.PubnetNetworkType: { + "EUROC": pubnetEUROCIssuer, + "USDC": pubnetUSDCIssuer, + }, + } + + buf := new(strings.Builder) + log.DefaultLogger.SetLevel(log.InfoLevel) + log.DefaultLogger.SetOutput(buf) + + err = SetupAssetsForProperNetwork(ctx, dbConnectionPool, utils.PubnetNetworkType, assetsNetworkMap) + require.NoError(t, err) + + assets, err = models.Assets.GetAll(ctx) + require.NoError(t, err) + assert.Len(t, assets, 3) + assert.Equal(t, "ARST", assets[0].Code) + assert.Equal(t, pubnetARSTIssuer, assets[0].Issuer) + assert.Equal(t, "EUROC", assets[1].Code) + assert.Equal(t, pubnetEUROCIssuer, assets[1].Issuer) + assert.Equal(t, "USDC", assets[2].Code) + assert.Equal(t, pubnetUSDCIssuer, assets[2].Issuer) + + expectedLogs := []string{ + "updating/inserting assets for the 'pubnet' network", + "Code: ARST", + "Code: EUROC", + "Code: USDC", + fmt.Sprintf("Issuer: %s", pubnetARSTIssuer), + fmt.Sprintf("Issuer: %s", pubnetEUROCIssuer), + fmt.Sprintf("Issuer: %s", pubnetUSDCIssuer), + } + + logs := buf.String() + for _, expectedLog := range expectedLogs { + assert.Contains(t, logs, expectedLog) + } + }) +} diff --git a/internal/services/setup_wallets_for_network_service.go b/internal/services/setup_wallets_for_network_service.go new file mode 100644 index 000000000..b9a710d97 --- /dev/null +++ b/internal/services/setup_wallets_for_network_service.go @@ -0,0 +1,145 @@ +package services + +import ( + "context" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +type WalletInfo struct { + Name, Homepage, DeepLinkSchema, SEP10ClientDomain string +} + +type WalletsNetworkMapType map[utils.NetworkType][]WalletInfo + +var DefaultWalletsNetworkMap = WalletsNetworkMapType{ + utils.PubnetNetworkType: { + { + Name: "Vibrant Assist", + Homepage: "https://vibrantapp.com/assist", + DeepLinkSchema: "https://vibrantapp.com/sdp", + SEP10ClientDomain: "api.vibrantapp.com", + }, + // { + // Name: "Beans App", + // Homepage: "https://www.beansapp.com/disbursements", + // DeepLinkSchema: "https://www.beansapp.com/disbursements/registration?redirect=true", + // SEP10ClientDomain: "api.beansapp.com", + // }, + }, + utils.TestnetNetworkType: { + { + Name: "Demo Wallet", + Homepage: "https://demo-wallet.stellar.org", + DeepLinkSchema: "https://demo-wallet.stellar.org", + SEP10ClientDomain: "demo-wallet-server.stellar.org", + }, + }, +} + +// SetupWalletsForProperNetwork updates and inserts wallets for the given Network Passphrase (`network`). So it avoids the application having +// wallets that doesn't support the given network. +func SetupWalletsForProperNetwork(ctx context.Context, dbConnectionPool db.DBConnectionPool, network utils.NetworkType, walletsNetworkMap WalletsNetworkMapType) error { + log.Ctx(ctx).Infof("updating/inserting wallets for the '%s' network\n\n", network) + + wallets, ok := walletsNetworkMap[network] + if !ok { + return fmt.Errorf("invalid network provided") + } + + var names, homepages, deepLinkSchemas, sep10ClientDomains []string + + separator := strings.Repeat("-", 20) + buf := new(strings.Builder) + buf.WriteString("wallets that will be updated or inserted:\n\n") + for _, wallet := range wallets { + names = append(names, wallet.Name) + homepages = append(homepages, wallet.Homepage) + deepLinkSchemas = append(deepLinkSchemas, wallet.DeepLinkSchema) + sep10ClientDomains = append(sep10ClientDomains, wallet.SEP10ClientDomain) + + buf.WriteString(fmt.Sprintf("%s\n%s\n\n", wallet.Name, separator)) + } + + log.Ctx(ctx).Info(buf.String()) + + err := db.RunInTransaction(ctx, dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + query := ` + WITH wallets_to_update_or_insert AS ( + -- gather all wallets passed as parameters for the query and turn into SQL rows + SELECT + UNNEST($1::text[]) AS name, UNNEST($2::text[]) AS homepage, + UNNEST($3::text[]) AS deep_link_schema, UNNEST($4::text[]) AS sep_10_client_domain + ), + existing_wallets AS ( + -- gets all wallets that the name appears in the names passed as parameter for the query + SELECT + * + FROM + wallets + WHERE + name = ANY($1::text[]) + FOR UPDATE + ), + update_existing_wallets AS ( + -- updates the existing wallets resulted in 'existing_wallets' CTE + UPDATE + wallets w + SET + homepage = wtui.homepage, + deep_link_schema = wtui.deep_link_schema, + sep_10_client_domain = wtui.sep_10_client_domain + FROM + existing_wallets ew + INNER JOIN wallets_to_update_or_insert wtui ON ew.name = wtui.name + WHERE + w.id = ew.id + ) + -- inserts wallets in the database + INSERT INTO wallets + (name, homepage, deep_link_schema, sep_10_client_domain) + SELECT + wtui.name, wtui.homepage, wtui.deep_link_schema, wtui.sep_10_client_domain + FROM + wallets_to_update_or_insert wtui + WHERE + wtui.name NOT IN (SELECT name FROM existing_wallets) + ` + + _, err := dbTx.ExecContext(ctx, query, pq.Array(names), pq.Array(homepages), pq.Array(deepLinkSchemas), pq.Array(sep10ClientDomains)) + if err != nil { + return fmt.Errorf("error upserting wallets: %w", err) + } + + return nil + }) + if err != nil { + return fmt.Errorf("error upserting wallets for the proper network: %w", err) + } + + models, err := data.NewModels(dbConnectionPool) + if err != nil { + return fmt.Errorf("error getting models: %w", err) + } + + allWallets, err := models.Wallets.GetAll(ctx) + if err != nil { + return fmt.Errorf("error getting all available wallets on database: %w", err) + } + + buf.Reset() + buf.WriteString(fmt.Sprintf("Registered wallets for network %s:\n\n", network)) + for _, wallet := range allWallets { + buf.WriteString(fmt.Sprintf("Name: %s\nHomepage: %s\nDeep Link Schema: %s\nSEP-10 Client Domain: %s\n%s\n\n", wallet.Name, wallet.Homepage, wallet.DeepLinkSchema, wallet.SEP10ClientDomain, separator)) + } + + log.Ctx(ctx).Info(buf.String()) + + return nil +} diff --git a/internal/services/setup_wallets_for_network_service_test.go b/internal/services/setup_wallets_for_network_service_test.go new file mode 100644 index 000000000..ef6b5431a --- /dev/null +++ b/internal/services/setup_wallets_for_network_service_test.go @@ -0,0 +1,176 @@ +package services + +import ( + "context" + "strings" + "testing" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_SetupWalletsForProperNetwork(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + ctx := context.Background() + t.Run("returns error when a invalid network is set", func(t *testing.T) { + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + + err = SetupWalletsForProperNetwork(ctx, dbConnectionPool, "invalid", DefaultWalletsNetworkMap) + assert.EqualError(t, err, "invalid network provided") + }) + + t.Run("inserts new wallets when it doesn't exist", func(t *testing.T) { + data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + + buf := new(strings.Builder) + log.DefaultLogger.SetLevel(log.InfoLevel) + log.DefaultLogger.SetOutput(buf) + + err = SetupWalletsForProperNetwork(ctx, dbConnectionPool, utils.PubnetNetworkType, DefaultWalletsNetworkMap) + require.NoError(t, err) + + wallets, err := models.Wallets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, wallets, 1) + // assert.Equal(t, "Beans App", wallets[0].Name) + assert.Equal(t, "Vibrant Assist", wallets[0].Name) + + expectedLogs := []string{ + "updating/inserting wallets for the 'pubnet' network", + // "Name: Beans App", + // "Homepage: https://www.beansapp.com/disbursements", + // "Deep Link Schema: https://www.beansapp.com/disbursements/registration?redirect=true", + // "SEP-10 Client Domain: api.beansapp.com", + "Name: Vibrant Assist", + "Homepage: https://vibrantapp.com/assist", + "Deep Link Schema: https://vibrantapp.com/sdp", + "SEP-10 Client Domain: api.vibrantapp.com", + } + + logs := buf.String() + for _, expectedLog := range expectedLogs { + assert.Contains(t, logs, expectedLog) + } + }) + + t.Run("updates and inserts wallets", func(t *testing.T) { + data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + + data.CreateWalletFixture(t, ctx, dbConnectionPool, "Vibrant Assist", "https://vibrantapp.com", "api-dev.vibrantapp.com", "https://vibrantapp.com/sdp-dev") + + wallets, err := models.Wallets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, wallets, 1) + assert.Equal(t, "Vibrant Assist", wallets[0].Name) + assert.Equal(t, "https://vibrantapp.com", wallets[0].Homepage) + assert.Equal(t, "api-dev.vibrantapp.com", wallets[0].SEP10ClientDomain) + assert.Equal(t, "https://vibrantapp.com/sdp-dev", wallets[0].DeepLinkSchema) + + walletsNetworkMap := WalletsNetworkMapType{ + utils.PubnetNetworkType: { + { + Name: "Vibrant Assist", + Homepage: "https://vibrantapp.com/vibrant-assist", + DeepLinkSchema: "https://aidpubnet.netlify.app", + SEP10ClientDomain: "api.vibrantapp.com", + }, + { + Name: "BOSS Money", + Homepage: "https://www.walletbyboss.com", + DeepLinkSchema: "https://www.walletbyboss.com", + SEP10ClientDomain: "www.walletbyboss.com", + }, + }, + } + + buf := new(strings.Builder) + log.DefaultLogger.SetLevel(log.InfoLevel) + log.DefaultLogger.SetOutput(buf) + + err = SetupWalletsForProperNetwork(ctx, dbConnectionPool, utils.PubnetNetworkType, walletsNetworkMap) + require.NoError(t, err) + + wallets, err = models.Wallets.GetAll(ctx) + require.NoError(t, err) + + assert.Len(t, wallets, 2) + assert.Equal(t, "BOSS Money", wallets[0].Name) + assert.Equal(t, "https://www.walletbyboss.com", wallets[0].Homepage) + assert.Equal(t, "www.walletbyboss.com", wallets[0].SEP10ClientDomain) + assert.Equal(t, "https://www.walletbyboss.com", wallets[0].DeepLinkSchema) + + assert.Equal(t, "Vibrant Assist", wallets[1].Name) + assert.Equal(t, "https://vibrantapp.com/vibrant-assist", wallets[1].Homepage) + assert.Equal(t, "api.vibrantapp.com", wallets[1].SEP10ClientDomain) + assert.Equal(t, "https://aidpubnet.netlify.app", wallets[1].DeepLinkSchema) + + expectedLogs := []string{ + "updating/inserting wallets for the 'pubnet' network", + "Name: BOSS Money", + "Homepage: https://www.walletbyboss.com", + "Deep Link Schema: https://www.walletbyboss.com", + "SEP-10 Client Domain: www.walletbyboss.com", + "Name: Vibrant Assist", + "Homepage: https://vibrantapp.com/vibrant-assist", + "Deep Link Schema: https://aidpubnet.netlify.app", + "SEP-10 Client Domain: api.vibrantapp.com", + } + + logs := buf.String() + for _, expectedLog := range expectedLogs { + assert.Contains(t, logs, expectedLog) + } + }) + + // Ensure the BOSS Money bug doesn't happen again on Testnet. Please refer to: https://stellarfoundation.slack.com/archives/C018BLTP2AU/p1686690282162189 + t.Run("duplicated constraint error", func(t *testing.T) { + // creates the Vibrant Assist and BOSS Money wallets + data.ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool) + + walletNetworkMap := WalletsNetworkMapType{ + utils.TestnetNetworkType: { + { + Name: "Boss Money", + Homepage: "https://www.walletbyboss.com", + DeepLinkSchema: "https://www.walletbyboss.com", + SEP10ClientDomain: "www.walletbyboss.com", + }, + { + Name: "Vibrant Assist", + Homepage: "https://vibrantapp.com", + DeepLinkSchema: "https://vibrantapp.com/sdp-dev", + SEP10ClientDomain: "api-dev.vibrantapp.com", + }, + }, + } + + err := SetupWalletsForProperNetwork(ctx, dbConnectionPool, utils.TestnetNetworkType, walletNetworkMap) + + // The problem was that in the DefaultWalletsNetworkMap, in the `testnet` key, we used the name `Boss Money` and not `BOSS Money` + // to refer to the BOSS Money wallet. So the query tried to insert the `Boss Money` wallet, but since the `homepage` and `deep_link_schema` + // were the same as the already inserted then, the insert statement resulted in a duplicated constraint error. + assert.EqualError(t, err, `error upserting wallets for the proper network: running atomic function in RunInTransactionWithResult: error upserting wallets: pq: duplicate key value violates unique constraint "wallets_homepage_key"`) + + // DefaultNetworkMap test - should NOT error + data.ClearAndCreateWalletFixtures(t, ctx, dbConnectionPool) + + err = SetupWalletsForProperNetwork(ctx, dbConnectionPool, utils.TestnetNetworkType, DefaultWalletsNetworkMap) + require.NoError(t, err) + }) +} diff --git a/internal/services/tss_monitor_service.go b/internal/services/tss_monitor_service.go new file mode 100644 index 000000000..8b862463a --- /dev/null +++ b/internal/services/tss_monitor_service.go @@ -0,0 +1,150 @@ +package services + +import ( + "context" + "fmt" + + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + txSubStore "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" +) + +type TSSMonitorService struct { + sdpModels *data.Models + tssModel *txSubStore.TransactionModel +} + +// MonitorTransactions monitors TSS transactions and updates payments accordingly. +func (s TSSMonitorService) MonitorTransactions(ctx context.Context, batchSize int) error { + err := db.RunInTransaction(ctx, s.sdpModels.DBConnectionPool, nil, func(dbTx db.DBTransaction) error { + return s.monitorTransactions(ctx, dbTx, batchSize) + }) + if err != nil { + return fmt.Errorf("error sending payments: %w", err) + } + + return nil +} + +// monitorTransactions monitors TSS transactions and updates payments accordingly. +func (s TSSMonitorService) monitorTransactions(ctx context.Context, dbTx db.DBTransaction, batchSize int) error { + // 1. Get transactions that are in a final state (status=SUCCESS or status=ERROR) + // this operation will lock the rows. + transactions, err := s.tssModel.GetTransactionBatchForUpdate(ctx, dbTx, batchSize) + if err != nil { + return fmt.Errorf("error getting transactions for update: %w", err) + } + if len(transactions) == 0 { + log.Ctx(ctx).Infof("No transactions to sync") + return nil + } + + // 2. Split transactions into successful and failed + failedTransactions := []*txSubStore.Transaction{} + successfulTransactions := []*txSubStore.Transaction{} + for _, transaction := range transactions { + if !transaction.StellarTransactionHash.Valid { + return fmt.Errorf("expected transaction %s to have a stellar transaction hash", transaction.ID) + } + if transaction.Status == txSubStore.TransactionStatusSuccess { + successfulTransactions = append(successfulTransactions, transaction) + } else if transaction.Status == txSubStore.TransactionStatusError { + failedTransactions = append(failedTransactions, transaction) + } else { + return fmt.Errorf("transaction id %s is in an unexpected status: %s", transaction.ID, transaction.Status) + } + } + + // 3. Update payments based on the status of the transactions + if len(successfulTransactions) > 0 { + log.Ctx(ctx).Infof("Syncing payments for %d successful transactions", len(successfulTransactions)) + errPayments := s.syncPaymentsWithTransactions(ctx, dbTx, successfulTransactions, data.SuccessPaymentStatus) + if errPayments != nil { + return fmt.Errorf("error syncing payments for successful transactions: %w", errPayments) + } + } + if len(failedTransactions) > 0 { + log.Ctx(ctx).Infof("Syncing payments for %d failed transactions", len(failedTransactions)) + errPayments := s.syncPaymentsWithTransactions(ctx, dbTx, failedTransactions, data.FailedPaymentStatus) + if errPayments != nil { + return fmt.Errorf("error syncing payments for failed transactions: %w", errPayments) + } + } + + // 4. Set synced_at for all synced transactions + transactionIDs := make([]string, len(transactions)) + for i, transaction := range transactions { + transactionIDs[i] = transaction.ID + } + err = s.tssModel.UpdateSyncedTransactions(ctx, dbTx, transactionIDs) + if err != nil { + return fmt.Errorf("error updating transactions as synced: %w", err) + } + + return nil +} + +// syncPaymentsWithTransactions updates the status of the payments based on the status of the transactions. +func (s TSSMonitorService) syncPaymentsWithTransactions(ctx context.Context, dbTx db.DBTransaction, transactions []*txSubStore.Transaction, toStatus data.PaymentStatus) error { + paymentIDs := make([]string, len(transactions)) + for i, transaction := range transactions { + paymentIDs[i] = transaction.ExternalID + } + payments, errPayments := s.sdpModels.Payment.GetByIDs(ctx, dbTx, paymentIDs) + if errPayments != nil { + return fmt.Errorf("error getting payments by ids: %w", errPayments) + } + + // Create a map of disbursement id from payment + disbursementMap := make(map[string]struct{}, len(payments)) + paymentMap := make(map[string]*data.Payment, len(payments)) + + for _, payment := range payments { + if payment.Status != data.PendingPaymentStatus { + return fmt.Errorf("error getting payments by ids: expected payment %s to be in pending status but got %s", payment.ID, payment.Status) + } + paymentMap[payment.ID] = payment + disbursementMap[payment.Disbursement.ID] = struct{}{} + } + + // Update payment status for each transaction to SUCCESS or FAILURE + for _, transaction := range transactions { + payment := paymentMap[transaction.ExternalID] + if payment == nil { + // The payment associated with this transaction was deleted. + log.Ctx(ctx).Errorf("orphaned transaction - Unable to sync transaction %s because the associated payment %s was deleted", + transaction.ID, + transaction.ExternalID) + continue + } + paymentUpdate := &data.PaymentUpdate{ + Status: toStatus, + StatusMessage: transaction.StatusMessage.String, + StellarTransactionID: transaction.StellarTransactionHash.String, + } + errUpdate := s.sdpModels.Payment.Update(ctx, dbTx, payment, paymentUpdate) + if errUpdate != nil { + return fmt.Errorf("error updating payment id %s for transaction id %s: %w", payment.ID, transaction.ID, errUpdate) + } + } + + disbursementIDs := make([]string, 0, len(disbursementMap)) + for disbursement := range disbursementMap { + disbursementIDs = append(disbursementIDs, disbursement) + } + err := s.sdpModels.Disbursements.CompleteDisbursements(ctx, dbTx, disbursementIDs) + if err != nil { + return fmt.Errorf("error completing disbursement: %w", err) + } + + return nil +} + +// NewTSSMonitorService creates a new TSSMonitorService instance. +func NewTSSMonitorService(models *data.Models) *TSSMonitorService { + return &TSSMonitorService{ + sdpModels: models, + tssModel: txSubStore.NewTransactionModel(models.DBConnectionPool), + } +} diff --git a/internal/services/tss_monitor_service_test.go b/internal/services/tss_monitor_service_test.go new file mode 100644 index 000000000..db92b5cb4 --- /dev/null +++ b/internal/services/tss_monitor_service_test.go @@ -0,0 +1,588 @@ +package services + +import ( + "context" + "database/sql" + "fmt" + "strings" + "testing" + + "github.com/stellar/go/support/log" + "github.com/stretchr/testify/assert" + + "github.com/lib/pq" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + txSubStore "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stretchr/testify/require" +) + +type testContext struct { + tssModel *txSubStore.TransactionModel + sdpModel *data.Models + ctx context.Context +} + +func setupTestContext(t *testing.T, dbConnectionPool db.DBConnectionPool) *testContext { + t.Helper() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + tssModel := txSubStore.NewTransactionModel(models.DBConnectionPool) + + return &testContext{ + tssModel: tssModel, + sdpModel: models, + ctx: context.Background(), + } +} + +func Test_TSSMonitorService_MonitorTransactions(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + testCtx := setupTestContext(t, dbConnectionPool) + ctx := testCtx.ctx + + paymentService := NewSendPaymentsService(testCtx.sdpModel) + monitorService := NewTSSMonitorService(testCtx.sdpModel) + + // create fixtures + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, + "My Wallet", + "https://www.wallet.com", + "www.wallet.com", + "wallet1://") + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, + "USDC", + "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, + "FRA", + "France") + + // create disbursements + startedDisbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Disbursements, &data.Disbursement{ + Name: "ready disbursement", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + // create disbursement receivers + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver3 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiver4 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + + rw1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rw2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rw3 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver3.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + rw4 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver4.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + payment1 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{ + ReceiverWallet: rw1, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "100", + Status: data.ReadyPaymentStatus, + }) + payment2 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{ + ReceiverWallet: rw2, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "200", + Status: data.ReadyPaymentStatus, + }) + payment3 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{ + ReceiverWallet: rw3, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "300", + Status: data.ReadyPaymentStatus, + }) + payment4 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{ + ReceiverWallet: rw4, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "400", + Status: data.ReadyPaymentStatus, + }) + payment5 := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{ + ReceiverWallet: rw4, + Disbursement: startedDisbursement, + Asset: *asset, + Amount: "400", + Status: data.ReadyPaymentStatus, + }) + + outerErr = paymentService.SendBatchPayments(ctx, 5) + require.NoError(t, outerErr) + + transactions, outerErr := testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment1.ID, payment2.ID, payment3.ID, payment4.ID, payment5.ID}) + require.NoError(t, outerErr) + require.Len(t, transactions, 5) + + // Update Hash and status of transactions to simulate success + prepareTxsForSync(t, testCtx, transactions) + + // Fail the last transaction + updatedTransactions := updateTSSTransactionsToError(t, testCtx, []payloadToUpdateTSSTxToError{ + {transactionID: transactions[3].ID, statusMessages: "test-error"}, + {transactionID: transactions[4].ID, statusMessages: "another-test-error"}, + }) + require.Len(t, updatedTransactions, 2) + for _, updatedTransaction := range updatedTransactions { + utx := updatedTransaction + for i, transaction := range transactions { + if updatedTransaction.ID == transaction.ID { + transactions[i] = &utx + break + } + } + } + + t.Run("monitor successful tss transactions", func(t *testing.T) { + err := monitorService.MonitorTransactions(ctx, 5) + require.NoError(t, err) + + // check that successful payments are updated + for _, p := range []*data.Payment{payment1, payment2, payment3} { + payment, paymentErr := testCtx.sdpModel.Payment.Get(ctx, p.ID, dbConnectionPool) + require.NoError(t, paymentErr) + require.Equal(t, data.SuccessPaymentStatus, payment.Status) + txs, txErr := testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{p.ID}) + require.NoError(t, txErr) + require.Len(t, txs, 1) + require.Equal(t, fmt.Sprintf("test-hash-%s", txs[0].ID), payment.StellarTransactionID) + } + + // check that failed payment is updated + payment, paymentErr := testCtx.sdpModel.Payment.Get(ctx, payment4.ID, dbConnectionPool) + require.NoError(t, paymentErr) + require.Equal(t, data.FailedPaymentStatus, payment.Status) + txs, txErr := testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment4.ID}) + require.NoError(t, txErr) + require.Len(t, txs, 1) + require.Equal(t, fmt.Sprintf("test-hash-%s", txs[0].ID), payment.StellarTransactionID) + require.Len(t, payment.StatusHistory, 3) + require.Equal(t, payment.StatusHistory[2].Status, data.FailedPaymentStatus) + require.Equal(t, payment.StatusHistory[2].StatusMessage, "test-error") + + payment, paymentErr = testCtx.sdpModel.Payment.Get(ctx, payment5.ID, dbConnectionPool) + require.NoError(t, paymentErr) + require.Equal(t, data.FailedPaymentStatus, payment.Status) + txs, txErr = testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment5.ID}) + require.NoError(t, txErr) + require.Len(t, txs, 1) + require.Equal(t, fmt.Sprintf("test-hash-%s", txs[0].ID), payment.StellarTransactionID) + require.Len(t, payment.StatusHistory, 3) + require.Equal(t, payment.StatusHistory[2].Status, data.FailedPaymentStatus) + require.Equal(t, payment.StatusHistory[2].StatusMessage, "another-test-error") + + // validate transactions synced_at is updated. + txs, txErr = testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment1.ID, payment2.ID, payment3.ID, payment4.ID}) + require.NoError(t, txErr) + require.Len(t, txs, 4) + + for _, tx := range txs { + require.NotNil(t, tx.SyncedAt) + } + }) + + t.Run("error when hash is invalid", func(t *testing.T) { + prepareTxsForSync(t, testCtx, transactions) + q := `UPDATE submitter_transactions SET stellar_transaction_hash = '' WHERE id = $1` + _, err := dbConnectionPool.ExecContext(ctx, q, transactions[0].ID) + require.NoError(t, err) + + err = monitorService.MonitorTransactions(ctx, 5) + require.Error(t, err) + require.ErrorContainsf(t, err, "stellar transaction id is required", "error: %s", err.Error()) + }) + + t.Run("payment is not pending", func(t *testing.T) { + prepareTxsForSync(t, testCtx, transactions) + updatePaymentStatus(t, testCtx, payment1.ID, data.SuccessPaymentStatus) + + err := monitorService.MonitorTransactions(ctx, 5) + require.Error(t, err) + contains := fmt.Sprintf("expected payment %s to be in pending status but got SUCCESS", payment1.ID) + require.ErrorContainsf(t, err, contains, "error: %s", err.Error()) + }) + + t.Run("error for orphaned transactions", func(t *testing.T) { + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + prepareTxsForSync(t, testCtx, transactions) + // insert a transaction that is not associated with a payment + paymentID := "dummy_payment_id" + + tx, err := testCtx.tssModel.Insert(ctx, txSubStore.Transaction{ + ExternalID: paymentID, + AssetCode: asset.Code, + AssetIssuer: asset.Issuer, + Amount: 100, + Destination: rw1.StellarAddress, + }) + require.NoError(t, err) + + // Update transactions states PENDING->PROCESSING: + q := `UPDATE submitter_transactions SET stellar_transaction_hash = 'dummy_hash_123', status=$1 WHERE id = $2 RETURNING *` + err = dbConnectionPool.GetContext(ctx, tx, q, txSubStore.TransactionStatusProcessing, tx.ID) + require.NoError(t, err) + + tx, err = testCtx.tssModel.UpdateStatusToSuccess(ctx, *tx) + require.NoError(t, err) + assert.Equal(t, store.TransactionStatusSuccess, tx.Status) + assert.NotEmpty(t, tx.CompletedAt) + + err = monitorService.MonitorTransactions(ctx, 10) + require.NoError(t, err) + expectedError := fmt.Sprintf("orphaned transaction - Unable to sync transaction %s because the associated payment %s was deleted", tx.ID, paymentID) + assert.Contains(t, buf.String(), expectedError) + }) +} + +func prepareTxsForSync(t *testing.T, testCtx *testContext, transactions []*txSubStore.Transaction) { + t.Helper() + + txLen := len(transactions) + + var err error + + for _, tx := range transactions { + q := `UPDATE submitter_transactions SET stellar_transaction_hash = $1, status=$2 WHERE id = $3` + _, err = testCtx.tssModel.DBConnectionPool.ExecContext(testCtx.ctx, q, "test-hash-"+tx.ID, txSubStore.TransactionStatusProcessing, tx.ID) + require.NoError(t, err) + + tx, err = testCtx.tssModel.Get(testCtx.ctx, tx.ID) + + require.NoError(t, err) + + // Update transactions states PROCESSING->SUCCESS: + if tx.Status == txSubStore.TransactionStatusProcessing { + tx, err = testCtx.tssModel.UpdateStatusToSuccess(testCtx.ctx, *tx) + require.NoError(t, err) + assert.Equal(t, store.TransactionStatusSuccess, tx.Status) + assert.NotEmpty(t, tx.CompletedAt) + } + } + + transactionIDs := make([]string, txLen) + for i, tx := range transactions { + transactionIDs[i] = tx.ID + } + + unsyncTransactions(t, testCtx, transactionIDs) + + // Set payment status back to pending + for _, tx := range transactions { + updatePaymentStatus(t, testCtx, tx.ExternalID, data.PendingPaymentStatus) + } +} + +func updatePaymentStatus(t *testing.T, testCtx *testContext, paymentID string, status data.PaymentStatus) { + t.Helper() + + query := `UPDATE payments SET status = $1 WHERE id = $2` + result, err := testCtx.sdpModel.DBConnectionPool.ExecContext(testCtx.ctx, query, status, paymentID) + require.NoError(t, err) + rowsAffected, err := result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rowsAffected) +} + +func unsyncTransactions(t *testing.T, testCtx *testContext, transactionIDs []string) { + t.Helper() + + query := `UPDATE submitter_transactions SET synced_at = NULL WHERE id = ANY($1)` + _, err := testCtx.sdpModel.DBConnectionPool.ExecContext(testCtx.ctx, query, pq.Array(transactionIDs)) + require.NoError(t, err) +} + +type payloadToUpdateTSSTxToError struct { + transactionID string + statusMessages string +} + +func updateTSSTransactionsToError(t *testing.T, testCtx *testContext, txDataSlice []payloadToUpdateTSSTxToError) []txSubStore.Transaction { + t.Helper() + + var transactionIDs []string + var statusMessages []sql.NullString + for _, txData := range txDataSlice { + transactionIDs = append(transactionIDs, txData.transactionID) + statusMessages = append(statusMessages, sql.NullString{String: txData.statusMessages, Valid: txData.statusMessages != ""}) + } + + updatedTransactions := []txSubStore.Transaction{} + q := ` + UPDATE submitter_transactions + SET status = $1, status_message = u.status_message, completed_at = NOW() + FROM (SELECT UNNEST($2::text[]) as id, UNNEST($3::text[]) as status_message) as u + WHERE submitter_transactions.id = u.id + RETURNING *` + err := testCtx.sdpModel.DBConnectionPool.SelectContext(testCtx.ctx, &updatedTransactions, q, txSubStore.TransactionStatusError, pq.Array(transactionIDs), pq.Array(statusMessages)) + require.NoError(t, err) + + return updatedTransactions +} + +func Test_TSSMonitorService_RetryingPayment(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + testCtx := setupTestContext(t, dbConnectionPool) + ctx := testCtx.ctx + + paymentService := NewSendPaymentsService(testCtx.sdpModel) + monitorService := NewTSSMonitorService(testCtx.sdpModel) + + // clean test db + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + + // create fixtures + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://") + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE") + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Disbursements, &data.Disbursement{ + Name: "started disbursement", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{ + Amount: "100", + StellarTransactionID: "stellar-transaction-id-1", + StellarOperationID: "operation-id-1", + Status: data.ReadyPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + err := paymentService.SendBatchPayments(ctx, 1) + require.NoError(t, err) + + paymentDB, err := testCtx.sdpModel.Payment.Get(ctx, payment.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.PendingPaymentStatus, paymentDB.Status) + + transactions, err := testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment.ID}) + require.NoError(t, err) + require.Len(t, transactions, 1) + + transaction := transactions[0] + assert.Equal(t, payment.ID, transaction.ExternalID) + assert.Equal(t, txSubStore.TransactionStatusPending, transaction.Status) + + // GIVEN a payment that fails to be sent + prepareTxsForSync(t, testCtx, transactions) + updatedTransaction := updateTSSTransactionsToError(t, testCtx, []payloadToUpdateTSSTxToError{ + {transactionID: transaction.ID, statusMessages: "Failing Test"}, + }) + require.Len(t, updatedTransaction, 1) + transaction = &updatedTransaction[0] + assert.Equal(t, payment.ID, transaction.ExternalID) + assert.Equal(t, txSubStore.TransactionStatusError, transaction.Status) + + // WHEN the monitor service is called + err = monitorService.MonitorTransactions(ctx, 1) + require.NoError(t, err) + + // THEN the payment is synced to the error state + paymentDB, err = testCtx.sdpModel.Payment.Get(ctx, paymentDB.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.FailedPaymentStatus, paymentDB.Status) + assert.Len(t, paymentDB.StatusHistory, 3) + assert.Equal(t, data.FailedPaymentStatus, paymentDB.StatusHistory[2].Status) + assert.Equal(t, "Failing Test", paymentDB.StatusHistory[2].StatusMessage) + + // AND the payment is retried + err = testCtx.sdpModel.Payment.RetryFailedPayments(ctx, "email@test.com", paymentDB.ID) + require.NoError(t, err) + + paymentDB, err = testCtx.sdpModel.Payment.Get(ctx, paymentDB.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.ReadyPaymentStatus, paymentDB.Status) + + // AND a new transaction is created for the payment + err = paymentService.SendBatchPayments(ctx, 1) + require.NoError(t, err) + + paymentDB, err = testCtx.sdpModel.Payment.Get(ctx, payment.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.PendingPaymentStatus, paymentDB.Status) + + transactions, err = testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment.ID}) + require.NoError(t, err) + require.Len(t, transactions, 2) + + transaction1 := transactions[0] + transaction2 := transactions[1] + assert.Equal(t, txSubStore.TransactionStatusError, transaction1.Status) + assert.Equal(t, txSubStore.TransactionStatusPending, transaction2.Status) + + prepareTxsForSync(t, testCtx, transactions[1:]) + transaction2, err = testCtx.tssModel.Get(ctx, transaction2.ID) + require.NoError(t, err) + assert.Equal(t, txSubStore.TransactionStatusSuccess, transaction2.Status) + + err = monitorService.MonitorTransactions(ctx, 2) + require.NoError(t, err) + + paymentDB, err = testCtx.sdpModel.Payment.Get(ctx, paymentDB.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.SuccessPaymentStatus, paymentDB.Status) + assert.Len(t, paymentDB.StatusHistory, 6) + assert.Equal(t, data.SuccessPaymentStatus, paymentDB.StatusHistory[5].Status) + assert.Empty(t, paymentDB.StatusHistory[5].StatusMessage) +} + +func Test_TSSMonitorService_CompleteDisbursements(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + testCtx := setupTestContext(t, dbConnectionPool) + ctx := testCtx.ctx + + paymentService := NewSendPaymentsService(testCtx.sdpModel) + monitorService := NewTSSMonitorService(testCtx.sdpModel) + + // clean test db + data.DeleteAllPaymentsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllDisbursementFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiverWalletsFixtures(t, ctx, dbConnectionPool) + data.DeleteAllReceiversFixtures(t, ctx, dbConnectionPool) + data.DeleteAllAssetFixtures(t, ctx, dbConnectionPool) + data.DeleteAllWalletFixtures(t, ctx, dbConnectionPool) + data.DeleteAllCountryFixtures(t, ctx, dbConnectionPool) + + // create fixtures + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "BRA", "Brazil") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "Wallet", "https://www.wallet.com", "www.wallet.com", "wallet://") + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GABC65XJDMXTGPNZRCI6V3KOKKWVK55UEKGQLONRIVYPMEJNNQ45YOEE") + + receiver := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver.ID, wallet.ID, data.RegisteredReceiversWalletStatus) + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Disbursements, &data.Disbursement{ + Name: "started disbursement", + Status: data.StartedDisbursementStatus, + Asset: asset, + Wallet: wallet, + Country: country, + }) + + payment := data.CreatePaymentFixture(t, ctx, dbConnectionPool, testCtx.sdpModel.Payment, &data.Payment{ + Amount: "100", + StellarTransactionID: "stellar-transaction-id-2", + StellarOperationID: "operation-id-2", + Status: data.ReadyPaymentStatus, + Disbursement: disbursement, + ReceiverWallet: receiverWallet, + Asset: *asset, + }) + + err := paymentService.SendBatchPayments(ctx, 1) + require.NoError(t, err) + + paymentDB, err := testCtx.sdpModel.Payment.Get(ctx, payment.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.PendingPaymentStatus, paymentDB.Status) + + transactions, err := testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment.ID}) + require.NoError(t, err) + require.Len(t, transactions, 1) + + transaction := transactions[0] + assert.Equal(t, payment.ID, transaction.ExternalID) + assert.Equal(t, txSubStore.TransactionStatusPending, transaction.Status) + + // GIVEN a payment that fails to be sent + prepareTxsForSync(t, testCtx, transactions) + updatedTransaction := updateTSSTransactionsToError(t, testCtx, []payloadToUpdateTSSTxToError{ + {transactionID: transaction.ID, statusMessages: "Failing Test"}, + }) + require.Len(t, updatedTransaction, 1) + transaction = &updatedTransaction[0] + assert.Equal(t, payment.ID, transaction.ExternalID) + assert.Equal(t, txSubStore.TransactionStatusError, transaction.Status) + + // WHEN the monitor service is called + err = monitorService.MonitorTransactions(ctx, 1) + require.NoError(t, err) + + // THEN the disbursement will not be completed + disbursement, err = testCtx.sdpModel.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + assert.Equal(t, data.StartedDisbursementStatus, disbursement.Status) + + // AND the payment is retried + err = testCtx.sdpModel.Payment.RetryFailedPayments(ctx, "email@test.com", paymentDB.ID) + require.NoError(t, err) + + paymentDB, err = testCtx.sdpModel.Payment.Get(ctx, paymentDB.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.ReadyPaymentStatus, paymentDB.Status) + + // AND a new transaction is created for the payment + err = paymentService.SendBatchPayments(ctx, 1) + require.NoError(t, err) + + paymentDB, err = testCtx.sdpModel.Payment.Get(ctx, payment.ID, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, data.PendingPaymentStatus, paymentDB.Status) + + transactions, err = testCtx.tssModel.GetAllByPaymentIDs(ctx, []string{payment.ID}) + require.NoError(t, err) + require.Len(t, transactions, 2) + + transaction1 := transactions[0] + transaction2 := transactions[1] + assert.Equal(t, txSubStore.TransactionStatusError, transaction1.Status) + assert.Equal(t, txSubStore.TransactionStatusPending, transaction2.Status) + + prepareTxsForSync(t, testCtx, transactions[1:]) + transaction2, err = testCtx.tssModel.Get(ctx, transaction2.ID) + require.NoError(t, err) + assert.Equal(t, txSubStore.TransactionStatusSuccess, transaction2.Status) + + // WHEN the monitor service is called again + err = monitorService.MonitorTransactions(ctx, 2) + require.NoError(t, err) + + // THEN disbursement gets completed + disbursement, err = testCtx.sdpModel.Disbursements.Get(ctx, dbConnectionPool, disbursement.ID) + require.NoError(t, err) + assert.Equal(t, data.CompletedDisbursementStatus, disbursement.Status) +} diff --git a/internal/statistics/calculate_statistics.go b/internal/statistics/calculate_statistics.go new file mode 100644 index 000000000..f06603f34 --- /dev/null +++ b/internal/statistics/calculate_statistics.go @@ -0,0 +1,386 @@ +package statistics + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +var ErrResourcesNotFound = errors.New("resources not found") + +type PaymentCounters struct { + Draft int64 `json:"draft"` + Ready int64 `json:"ready"` + Pending int64 `json:"pending"` + Paused int64 `json:"paused"` + Success int64 `json:"success"` + Failed int64 `json:"failed"` + Total int64 `json:"total"` +} + +type PaymentAmounts struct { + Draft string `json:"draft"` + Ready string `json:"ready"` + Pending string `json:"pending"` + Paused string `json:"paused"` + Success string `json:"success"` + Failed string `json:"failed"` + Average string `json:"average"` + Total string `json:"total"` +} + +type PaymentAmountsByAsset struct { + AssetCode string `json:"asset_code"` + PaymentAmounts PaymentAmounts `json:"payment_amounts"` +} + +type GeneralStatistics struct { + DisbursementsStatistics + TotalDisbursement int64 `json:"total_disbursements"` +} + +type DisbursementsStatistics struct { + PaymentCounters PaymentCounters `json:"payment_counters"` + PaymentAmountsByAsset []PaymentAmountsByAsset `json:"payment_amounts_by_asset"` + ReceiverWalletsCounters ReceiverWalletsCounters `json:"receiver_wallets_counters"` + TotalReceivers int64 `json:"total_receivers"` +} + +type ReceiverWalletsCounters struct { + Draft int64 `json:"draft"` + Ready int64 `json:"ready"` + Registered int64 `json:"registered"` + Flagged int64 `json:"flagged"` + Total int64 `json:"total"` +} + +// getPaymentsStats returns payment statistics aggregated by payment status, if a disbursement ID +// is sent in the parameters the payment stats will be calculated for a specific disbursement. +func getPaymentsStats(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) (*PaymentCounters, []PaymentAmountsByAsset, error) { + query := []string{ + 0: "SELECT code, status, Count(*), Sum(p.amount)", + 1: "FROM payments p", + 2: "JOIN assets a ON p.asset_id=a.id", + 3: "", + 4: "GROUP BY (a.code, p.status)", + 5: "ORDER BY (a.code);", + } + + var args []interface{} + if disbursementID != "" { + query[3] = "WHERE p.disbursement_id = $1" + args = append(args, disbursementID) + } + + rows, err := sqlExec.QueryxContext(ctx, strings.Join(query, " "), args...) + if err != nil { + return nil, nil, fmt.Errorf("getting payments data in getPaymentsStats: %w", err) + } + + defer db.CloseRows(ctx, rows) + + currentCode := "" + paymentCounters := PaymentCounters{} + paymentAmounts := PaymentAmounts{} + + paymentsAmountsByAsset := []PaymentAmountsByAsset{} + var totalAmount float64 + var totalCount int64 + + for rows.Next() { + var ( + code, status, amount string + count int64 + ) + + err = rows.Scan(&code, &status, &count, &amount) + if err != nil { + return nil, nil, fmt.Errorf("attributing values to rows in getPaymentsStats: %w", err) + } + + if currentCode != code { + + if currentCode != "" { + avg := totalAmount / float64(totalCount) + paymentAmounts.Total = utils.FloatToString(totalAmount) + paymentAmounts.Average = utils.FloatToString(avg) + totalAmount = 0 + totalCount = 0 + + paymentsAmountsByAsset = append( + paymentsAmountsByAsset, + PaymentAmountsByAsset{ + AssetCode: currentCode, + PaymentAmounts: paymentAmounts, + }, + ) + + paymentAmounts = PaymentAmounts{} + } + + currentCode = code + } + + switch data.PaymentStatus(status) { + case data.DraftPaymentStatus: + paymentCounters.Draft += count + paymentAmounts.Draft = amount + + case data.PendingPaymentStatus: + paymentCounters.Pending += count + paymentAmounts.Pending = amount + + case data.ReadyPaymentStatus: + paymentCounters.Ready += count + paymentAmounts.Ready = amount + + case data.SuccessPaymentStatus: + paymentCounters.Success += count + paymentAmounts.Success = amount + + case data.FailedPaymentStatus: + paymentCounters.Failed += count + paymentAmounts.Failed = amount + + case data.PausedPaymentStatus: + paymentCounters.Paused += count + paymentAmounts.Paused = amount + default: + return nil, nil, fmt.Errorf("status %v is not a valid payment status", status) + } + + paymentCounters.Total += count + + totalCount += count + if value, parseErr := strconv.ParseFloat(amount, 64); parseErr != nil { + return nil, nil, fmt.Errorf("error parsing payment amount: %w", err) + } else { + totalAmount += value + } + } + + if err = rows.Err(); err != nil { + return nil, nil, fmt.Errorf("end scanning: %w", err) + } + + if currentCode != "" { + avg := totalAmount / float64(totalCount) + paymentAmounts.Total = utils.FloatToString(totalAmount) + paymentAmounts.Average = utils.FloatToString(avg) + + paymentsAmountsByAsset = append( + paymentsAmountsByAsset, + PaymentAmountsByAsset{ + AssetCode: currentCode, + PaymentAmounts: paymentAmounts, + }, + ) + } + + return &paymentCounters, paymentsAmountsByAsset, nil +} + +// getReceiverWalletsStats returns receiver wallets statistics aggregated by receiver wallet status, if a disbursement +// ID is sent in the parameters the receiver wallet stats will be calculated for a specific disbursement. +func getReceiverWalletsStats(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) (*ReceiverWalletsCounters, error) { + query := []string{ + 0: "SELECT rw.status, Count(DISTINCT rw.receiver_id)", + 1: "FROM receiver_wallets rw", + 2: "LEFT JOIN payments p ON p.receiver_wallet_id=rw.id", + 3: "", + 4: "GROUP BY (rw.status);", + } + + var args []interface{} + if disbursementID != "" { + query[3] = "WHERE p.disbursement_id = $1" + args = append(args, disbursementID) + } + + rows, err := sqlExec.QueryxContext(ctx, strings.Join(query, " "), args...) + if err != nil { + return nil, fmt.Errorf("getting receivers wallet data by asset: %w", err) + } + + defer db.CloseRows(ctx, rows) + + receiverWalletsCounters := ReceiverWalletsCounters{} + + for rows.Next() { + var ( + status string + count int64 + ) + + err = rows.Scan(&status, &count) + + if err != nil { + return nil, fmt.Errorf("attributing values to rows: %w", err) + } + + switch data.ReceiversWalletStatus(status) { + case data.DraftReceiversWalletStatus: + receiverWalletsCounters.Draft = count + + case data.ReadyReceiversWalletStatus: + receiverWalletsCounters.Ready = count + + case data.RegisteredReceiversWalletStatus: + receiverWalletsCounters.Registered = count + + case data.FlaggedReceiversWalletStatus: + receiverWalletsCounters.Flagged = count + + default: + return nil, fmt.Errorf("status %v is not a valid receiver wallet status", status) + } + + receiverWalletsCounters.Total += count + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("end scanning: %w", err) + } + + return &receiverWalletsCounters, nil +} + +// getTotalReceivers returns total amount of receivers, if a disbursement ID is sent in the parameters +// then the total amount of receivers present in the specific disbursement is returned. +func getTotalReceivers(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) (int64, error) { + var args []interface{} + query := "SELECT COUNT(DISTINCT r.id) FROM receivers r" + + if disbursementID != "" { + query += " JOIN payments p ON p.receiver_id = r.id WHERE p.disbursement_id = $1" + args = append(args, disbursementID) + } + + var totalReceivers int64 + err := sqlExec.GetContext(ctx, &totalReceivers, query, args...) + if err != nil { + return 0, fmt.Errorf("getting total receiver data: %w", err) + } + + return totalReceivers, nil +} + +// getTotalDisbursements returns total amount of disbursements. +func getTotalDisbursements(ctx context.Context, sqlExec db.SQLExecuter) (totalDisbursement int64, err error) { + q := "SELECT COUNT(*) FROM disbursements" + err = sqlExec.GetContext(ctx, &totalDisbursement, q) + if err != nil { + return 0, fmt.Errorf("getting total disbursement data: %w", err) + } + + return totalDisbursement, nil +} + +// CalculateStatistics calculate statistics for all disbursements. +func CalculateStatistics(ctx context.Context, dbConnectionPool db.DBConnectionPool) (statistics *GeneralStatistics, err error) { + // Start transaction + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("starting transaction in CalculateStatistics: %w", err) + } + defer func() { + db.DBTxRollback(ctx, dbTx, err, "error in CalculateStatistics") + }() + + paymentCounters, paymentAmountByAsset, err := getPaymentsStats(ctx, dbTx, "") + if err != nil { + return nil, err + } + + receiverWalletsCounters, err := getReceiverWalletsStats(ctx, dbTx, "") + if err != nil { + return nil, err + } + + totalReceivers, err := getTotalReceivers(ctx, dbTx, "") + if err != nil { + return nil, err + } + + totalDisbursement, err := getTotalDisbursements(ctx, dbTx) + if err != nil { + return nil, err + } + + err = dbTx.Commit() + if err != nil { + return nil, fmt.Errorf("commiting transaction in CalculateStatistics: %w", err) + } + + statistics = &GeneralStatistics{TotalDisbursement: totalDisbursement} + statistics.PaymentCounters = *paymentCounters + statistics.PaymentAmountsByAsset = paymentAmountByAsset + statistics.ReceiverWalletsCounters = *receiverWalletsCounters + statistics.TotalReceivers = totalReceivers + return statistics, nil +} + +// CalculateStatisticsByDisbursement calculate statistics for a specific disbursement. +func CalculateStatisticsByDisbursement(ctx context.Context, dbConnectionPool db.DBConnectionPool, disbursementID string) (statistics *DisbursementsStatistics, err error) { + // Start transaction + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("starting transaction in CalculateStatisticsByDisbursement: %w", err) + } + defer func() { + db.DBTxRollback(ctx, dbTx, err, "error in CalculateStatisticsByDisbursement") + }() + + disbursementExists, err := checkIfDisbursementExists(ctx, dbTx, disbursementID) + if err != nil { + return nil, fmt.Errorf("checking if disbursement exists in CalculateStatisticsByDisbursement: %w", err) + } + if !disbursementExists { + return nil, ErrResourcesNotFound + } + + paymentCounters, paymentAmountByAsset, err := getPaymentsStats(ctx, dbTx, disbursementID) + if err != nil { + return nil, err + } + + receiverWalletsCounters, err := getReceiverWalletsStats(ctx, dbTx, disbursementID) + if err != nil { + return nil, err + } + + totalReceivers, err := getTotalReceivers(ctx, dbTx, disbursementID) + if err != nil { + return nil, err + } + + err = dbTx.Commit() + if err != nil { + return nil, fmt.Errorf("commiting transaction in CalculateStatisticsByDisbursement: %w", err) + } + + statistics = &DisbursementsStatistics{ + PaymentCounters: *paymentCounters, + PaymentAmountsByAsset: paymentAmountByAsset, + ReceiverWalletsCounters: *receiverWalletsCounters, + TotalReceivers: totalReceivers, + } + return statistics, nil +} + +func checkIfDisbursementExists(ctx context.Context, sqlExec db.SQLExecuter, disbursementID string) (exists bool, err error) { + // Check if the disbursement exists + query := "SELECT EXISTS(SELECT 1 FROM disbursements WHERE id = $1)" + err = sqlExec.QueryRowxContext(ctx, query, disbursementID).Scan(&exists) + if err != nil { + return false, fmt.Errorf("checking disbursement existence: %w", err) + } + + return exists, nil +} diff --git a/internal/statistics/calculate_statistics_test.go b/internal/statistics/calculate_statistics_test.go new file mode 100644 index 000000000..cca1ee486 --- /dev/null +++ b/internal/statistics/calculate_statistics_test.go @@ -0,0 +1,414 @@ +package statistics + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCalculateStatistics_emptyDatabase(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + t.Run("getPaymentsStats", func(t *testing.T) { + paymentsCounter, paymentsAmountByAsset, errPayments := getPaymentsStats(ctx, dbConnectionPool, "") + require.NoError(t, errPayments) + + // paymentsCounter assertions + assert.IsType(t, &PaymentCounters{}, paymentsCounter) + gotJsonCounter, errJson := json.Marshal(paymentsCounter) + require.NoError(t, errJson) + wantJsonCounter := `{ + "draft": 0, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 0, + "failed": 0, + "total": 0 + }` + assert.JSONEq(t, wantJsonCounter, string(gotJsonCounter)) + + // paymentsAmountByAsset assertions + assert.IsType(t, []PaymentAmountsByAsset{}, paymentsAmountByAsset) + gotJsonAmountByAsset, errJson := json.Marshal(paymentsAmountByAsset) + require.NoError(t, errJson) + wantJsonAmountByAsset := `[]` + assert.JSONEq(t, wantJsonAmountByAsset, string(gotJsonAmountByAsset)) + }) + + t.Run("getReceiverWalletsStats", func(t *testing.T) { + receiverWalletStats, errReceiver := getReceiverWalletsStats(ctx, dbConnectionPool, "") + require.NoError(t, errReceiver) + + // receiverWalletStats assertions + assert.IsType(t, &ReceiverWalletsCounters{}, receiverWalletStats) + gotJson, errJson := json.Marshal(receiverWalletStats) + require.NoError(t, errJson) + wantJson := `{ + "draft": 0, + "flagged": 0, + "ready": 0, + "registered": 0, + "total": 0 + }` + assert.JSONEq(t, wantJson, string(gotJson)) + }) + + t.Run("getTotalReceivers", func(t *testing.T) { + totalReceivers, err := getTotalReceivers(ctx, dbConnectionPool, "") + require.NoError(t, err) + assert.Equal(t, int64(0), totalReceivers) + }) + + t.Run("getTotalDisbursements", func(t *testing.T) { + totalDisbursements, err := getTotalDisbursements(ctx, dbConnectionPool) + require.NoError(t, err) + assert.Equal(t, int64(0), totalDisbursements) + }) +} + +func TestCalculateStatistics(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + models, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + asset1 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + receiver1 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet1 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver1.ID, wallet.ID, data.DraftReceiversWalletStatus) + + receiver2 := data.CreateReceiverFixture(t, ctx, dbConnectionPool, &data.Receiver{}) + receiverWallet2 := data.CreateReceiverWalletFixture(t, ctx, dbConnectionPool, receiver2.ID, wallet.ID, data.DraftReceiversWalletStatus) + + disbursement1 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 1", + Status: data.CompletedDisbursementStatus, + Asset: asset1, + Wallet: wallet, + Country: country, + }) + + stellarTransactionID, err := utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err := utils.RandomString(32) + require.NoError(t, err) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "10", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.DraftPaymentStatus, + Disbursement: disbursement1, + Asset: *asset1, + ReceiverWallet: receiverWallet1, + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "10", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.DraftPaymentStatus, + Disbursement: disbursement1, + Asset: *asset1, + ReceiverWallet: receiverWallet2, + }) + + t.Run("get receiver wallet stats", func(t *testing.T) { + receiverWalletStats, errReceiver := getReceiverWalletsStats(ctx, dbConnectionPool, "") + require.NoError(t, errReceiver) + + assert.IsType(t, &ReceiverWalletsCounters{}, receiverWalletStats) + + gotJson, errJson := json.Marshal(receiverWalletStats) + require.NoError(t, errJson) + + wantJson := `{ + "draft": 2, + "flagged": 0, + "ready": 0, + "registered": 0, + "total": 2 + }` + + assert.JSONEq(t, wantJson, string(gotJson)) + }) + + t.Run("get total disbursement", func(t *testing.T) { + totalDisbursement, errDisbursement := getTotalDisbursements(ctx, dbConnectionPool) + require.NoError(t, errDisbursement) + + assert.Equal(t, int64(1), totalDisbursement) + }) + + t.Run("get payment stats", func(t *testing.T) { + paymentsCounter, paymentsAmountByAsset, errPayments := getPaymentsStats(ctx, dbConnectionPool, "") + require.NoError(t, errPayments) + + assert.IsType(t, &PaymentCounters{}, paymentsCounter) + assert.IsType(t, []PaymentAmountsByAsset{}, paymentsAmountByAsset) + + gotJsonCounter, errJson := json.Marshal(paymentsCounter) + require.NoError(t, errJson) + + wantJsonCounter := `{ + "draft": 2, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 0, + "failed": 0, + "total": 2 + }` + + assert.JSONEq(t, wantJsonCounter, string(gotJsonCounter)) + + gotJsonAmountByAsset, errJson := json.Marshal(paymentsAmountByAsset) + require.NoError(t, errJson) + + wantJsonAmountByAsset := `[ + { + "asset_code": "USDC", + "payment_amounts": { + "draft": "20.0000000", + "ready": "", + "pending": "", + "paused": "", + "success": "", + "failed": "", + "average": "10.0000000", + "total": "20.0000000" + } + } + ]` + + assert.JSONEq(t, wantJsonAmountByAsset, string(gotJsonAmountByAsset)) + }) + + asset2 := data.CreateAssetFixture(t, ctx, dbConnectionPool, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + + disbursement2 := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, models.Disbursements, &data.Disbursement{ + Name: "disbursement 2", + Status: data.CompletedDisbursementStatus, + Asset: asset2, + Wallet: wallet, + Country: country, + }) + + stellarTransactionID, err = utils.RandomString(64) + require.NoError(t, err) + stellarOperationID, err = utils.RandomString(32) + require.NoError(t, err) + + data.CreatePaymentFixture(t, ctx, dbConnectionPool, models.Payment, &data.Payment{ + Amount: "10", + StellarTransactionID: stellarTransactionID, + StellarOperationID: stellarOperationID, + Status: data.SuccessPaymentStatus, + Disbursement: disbursement2, + Asset: *asset2, + ReceiverWallet: receiverWallet1, + }) + + t.Run("get payment stats with multiple assets codes", func(t *testing.T) { + paymentsCounter, paymentsAmountByAsset, err := getPaymentsStats(ctx, dbConnectionPool, "") + require.NoError(t, err) + + assert.IsType(t, &PaymentCounters{}, paymentsCounter) + assert.IsType(t, []PaymentAmountsByAsset{}, paymentsAmountByAsset) + + gotJsonCounter, err := json.Marshal(paymentsCounter) + require.NoError(t, err) + + wantJsonCounter := `{ + "draft": 2, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 1, + "failed": 0, + "total": 3 + }` + + assert.JSONEq(t, wantJsonCounter, string(gotJsonCounter)) + + gotJsonAmountByAsset, err := json.Marshal(paymentsAmountByAsset) + require.NoError(t, err) + + wantJsonAmountByAsset := `[ + { + "asset_code": "EURT", + "payment_amounts": { + "draft": "", + "ready": "", + "pending": "", + "paused": "", + "success": "10.0000000", + "failed": "", + "average": "10.0000000", + "total": "10.0000000" + } + }, + { + "asset_code": "USDC", + "payment_amounts": { + "draft": "20.0000000", + "ready": "", + "pending": "", + "paused": "", + "success": "", + "failed": "", + "average": "10.0000000", + "total": "20.0000000" + } + } + ]` + + assert.JSONEq(t, wantJsonAmountByAsset, string(gotJsonAmountByAsset)) + }) + + t.Run("get payment stats for specific disbursement", func(t *testing.T) { + paymentsCounter, paymentsAmountByAsset, err := getPaymentsStats(ctx, dbConnectionPool, disbursement2.ID) + require.NoError(t, err) + + assert.IsType(t, &PaymentCounters{}, paymentsCounter) + assert.IsType(t, []PaymentAmountsByAsset{}, paymentsAmountByAsset) + + gotJsonCounter, err := json.Marshal(paymentsCounter) + require.NoError(t, err) + + wantJsonCounter := `{ + "draft": 0, + "ready": 0, + "pending": 0, + "paused": 0, + "success": 1, + "failed": 0, + "total": 1 + }` + + assert.JSONEq(t, wantJsonCounter, string(gotJsonCounter)) + + gotJsonAmountByAsset, err := json.Marshal(paymentsAmountByAsset) + require.NoError(t, err) + + wantJsonAmountByAsset := `[ + { + "asset_code": "EURT", + "payment_amounts": { + "draft": "", + "ready": "", + "pending": "", + "paused": "", + "success": "10.0000000", + "failed": "", + "average": "10.0000000", + "total": "10.0000000" + } + } + ]` + + assert.JSONEq(t, wantJsonAmountByAsset, string(gotJsonAmountByAsset)) + }) + + t.Run("get receiver wallet stats for specific disbursement", func(t *testing.T) { + receiverWalletStats, err := getReceiverWalletsStats(ctx, dbConnectionPool, disbursement2.ID) + require.NoError(t, err) + + assert.IsType(t, &ReceiverWalletsCounters{}, receiverWalletStats) + + gotJson, err := json.Marshal(receiverWalletStats) + require.NoError(t, err) + + wantJson := `{ + "draft": 1, + "flagged": 0, + "ready": 0, + "registered": 0, + "total": 1 + }` + + assert.JSONEq(t, wantJson, string(gotJson)) + }) + + t.Run("get total receivers", func(t *testing.T) { + totalReceivers, err := getTotalReceivers(ctx, dbConnectionPool, "") + require.NoError(t, err) + assert.Equal(t, int64(2), totalReceivers) + }) + + t.Run("get total receivers with disbursement ID", func(t *testing.T) { + totalReceivers, err := getTotalReceivers(ctx, dbConnectionPool, disbursement2.ID) + require.NoError(t, err) + assert.Equal(t, int64(1), totalReceivers) + }) +} + +func Test_checkIfDisbursementExists(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + model, err := data.NewModels(dbConnectionPool) + require.NoError(t, err) + + t.Run("disbursement does not exist", func(t *testing.T) { + exists, err := checkIfDisbursementExists(context.Background(), dbConnectionPool, "non-existing-id") + require.NoError(t, err) + assert.False(t, exists) + }) + + t.Run("disbursement exists", func(t *testing.T) { + asset := data.CreateAssetFixture(t, ctx, dbConnectionPool, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZVV") + country := data.CreateCountryFixture(t, ctx, dbConnectionPool, "FRA", "France") + wallet := data.CreateWalletFixture(t, ctx, dbConnectionPool, "wallet1", "https://www.wallet.com", "www.wallet.com", "wallet1://") + + disbursement := data.CreateDisbursementFixture(t, ctx, dbConnectionPool, model.Disbursements, &data.Disbursement{ + Status: data.DraftDisbursementStatus, + StatusHistory: []data.DisbursementStatusHistoryEntry{ + { + Status: data.DraftDisbursementStatus, + UserID: "user1", + }, + }, + Asset: asset, + Country: country, + Wallet: wallet, + }) + exists, err := checkIfDisbursementExists(context.Background(), dbConnectionPool, disbursement.ID) + require.NoError(t, err) + assert.True(t, exists) + }) +} diff --git a/internal/transactionsubmission/README.md b/internal/transactionsubmission/README.md new file mode 100644 index 000000000..e9a0abde9 --- /dev/null +++ b/internal/transactionsubmission/README.md @@ -0,0 +1,158 @@ +# Transaction Submission Service + +The Transaction Submission Service (TSS) is a component that is responsible for submitting payment transactions to the Stellar Network. + +The SDP will directly 'queue' transactions (create transactions in the database) and the Transaction Submission Service will read these transactions and submit them to the Stellar Network. + +The Transaction Submission Service requires channel accounts to be seeded in storage in advanced. To learn how to fulfill this prerequisite, please refer to the [Channel Accounts Management](#channel-accounts-management) section below. + +## Transaction Submitter +### CLI Usage: `tss` +```sh +$ stellar-disbursement-platform tss --help +Run the Transaction Submission Service + +Usage: + stellar-disbursement-platform tss [flags] + +Flags: + --crash-tracker-type string Crash tracker type. Options: "SENTRY", "DRY_RUN" (CRASH_TRACKER_TYPE) (default "DRY_RUN") + --distribution-seed string The private key of the Stellar account used to disburse funds (DISTRIBUTION_SEED) + -h, --help help for tss + --horizon-url string Horizon URL (HORIZON_URL) (default "https://horizon-testnet.stellar.org/") + --max-base-fee int The max base fee for submitting a Stellar transaction (MAX_BASE_FEE) (default 100) + --num-channel-accounts int Number of channel accounts to utilize for transaction submission (NUM_CHANNEL_ACCOUNTS) (default 2) + --queue-polling-interval int Polling interval (seconds) to query the database for pending transactions to process (QUEUE_POLLING_INTERVAL) (default 6) + --tss-metrics-port int Port where the metrics server will be listening on. Default: 9002" (TSS_METRICS_PORT) (default 9002) + --tss-metrics-type string Metric monitor type. Options: "TSS_PROMETHEUS" (TSS_METRICS_TYPE) (default "TSS_PROMETHEUS") + +Global Flags: + --base-url string The SDP UI base URL. (BASE_URL) (default "http://localhost:8000") + --database-url string Postgres DB URL (DATABASE_URL) (default "postgres://localhost:5432/sdp?sslmode=disable") + --environment string The environment where the application is running. Example: "development", "staging", "production". (ENVIRONMENT) (default "development") + --log-level string The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC". (LOG_LEVEL) (default "TRACE") + --network-passphrase string The Stellar network passphrase (NETWORK_PASSPHRASE) (default "Test SDF Network ; September 2015") + --sentry-dsn string The DSN (client key) of the Sentry project. If not provided, Sentry will not be used. (SENTRY_DSN) +``` + +## Channel Accounts Management + +Channel Accounts are used to increase throughput when submitting transaction to the Stellar Network, and are a prerequisite for using TSS. This CLI tools should enable all use cases for management of Channel Accounts (both onchain and in the database). + +### CLI Usage: `channel-accounts` +```sh +$ stellar-disbursement-platform channel-accounts --help +Channel accounts related commands + +Usage: + stellar-disbursement-platform channel-accounts [command] + +Available Commands: + create Create channel accounts + delete Delete a specified channel account from storage and on the network + ensure Ensure we are managing exactly the number of channel accounts equal to some specified count by dynamically increasing or decreasing the number of managed channel accounts in storage and onchain + verify Verify the existence of all channel accounts in the database on the Stellar newtwork + view View all channel accounts currently managed in the database + +Flags: + -h, --help help for channel-accounts + --horizon-url string Horizon URL" (HORIZON_URL) (default "https://horizon-testnet.stellar.org/") + +Global Flags: + --base-url string The SDP UI base URL. (BASE_URL) (default "http://localhost:8000") + --database-url string Postgres DB URL (DATABASE_URL) (default "postgres://localhost:5432/sdp?sslmode=disable") + --environment string The environment where the application is running. Example: "development", "staging", "production". (ENVIRONMENT) (default "development") + --log-level string The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC". (LOG_LEVEL) (default "TRACE") + --network-passphrase string The Stellar network passphrase (NETWORK_PASSPHRASE) (default "Test SDF Network ; September 2015") + --sentry-dsn string The DSN (client key) of the Sentry project. If not provided, Sentry will not be used. (SENTRY_DSN) + +Use "stellar-disbursement-platform channel-accounts [command] --help" for more information about a command. +``` + +### CLI Usage: `channel-accounts create` +```sh +channel-accounts create --help +Usage: + stellar-disbursement-platform channel-accounts create [flags] + +Flags: + --distribution-seed string The private key of the Stellar account that will be used to sponsor the channel accounts (DISTRIBUTION_SEED) + --encrypt-key Whether or not to encrypt the private key for storage (ENCRYPT_KEY) (default true) + -h, --help help for create + --max-base-fee int The max base fee for submitting a stellar transaction (MAX_BASE_FEE) (default 100) + --num-channel-accounts-create int The desired number of channel accounts to be created (NUM_CHANNEL_ACCOUNTS_CREATE) (default 1) +``` + +### CLI Usage: `channel-accounts ensure` +```sh +channel-accounts ensure --help +Usage: + stellar-disbursement-platform channel-accounts ensure [flags] + +Flags: + --distribution-seed string The private key of the Stellar account used to sponsor existing channel accounts (DISTRIBUTION_SEED) + --encrypt-key Whether or not to encrypt the private key for storage (ENCRYPT_KEY) (default true) + -h, --help help for ensure + --max-base-fee int The max base fee for submitting a stellar transaction (MAX_BASE_FEE) (default 100) + --num-channel-accounts-ensure int The desired number of channel accounts to manage (NUM_CHANNEL_ACCOUNTS_ENSURE) (default 1) +``` + +### CLI Usage: `channel-accounts delete` +```sh +channel-accounts delete --help +Usage: + stellar-disbursement-platform channel-accounts delete [flags] + +Flags: + --channel-account-id string The ID of the channel account to delete (CHANNEL_ACCOUNT_ID) + --delete-all-accounts Delete all managed channel accoounts in the database and on the network (DELETE_ALL_ACCOUNTS) + --distribution-seed string The private key of the Stellar account used to sponsor the channel account specified (DISTRIBUTION_SEED) + -h, --help help for delete + --max-base-fee int The max base fee for submitting a stellar transaction (MAX_BASE_FEE) (default 100) +``` + +### CLI Usage: `channel-accounts verify` +```sh +channel-accounts verify --help +Usage: + stellar-disbursement-platform channel-accounts verify [flags] + +Flags: + --delete-invalid-accounts Delete channel accounts from storage that are verified to be invalid on the network (DELETE_INVALID_ACCOUNTS) + -h, --help help for verify +``` + +### CLI Usage: `channel-accounts verify` +```sh +channel-accounts verify --help +Usage: + stellar-disbursement-platform channel-accounts verify [flags] + +Flags: + --delete-invalid-accounts Delete channel accounts from storage that are verified to be invalid on the network (DELETE_INVALID_ACCOUNTS) + -h, --help help for verify +``` + +### CLI Usage: `channel-accounts view` +```sh +channel-accounts view --help +Usage: + stellar-disbursement-platform channel-accounts view [flags] + +Flags: + -h, --help help for view +``` + +## Testing +### Mocks +TSS unit tests rely on mocks of its interfaces auto-generated by mockery. For installation instructions, see [here](https://vektra.github.io/mockery/installation/). + +Refer to the output to learn how to annotate interfaces and about the different flags that you can leverage to manipulate the output. +``` +mockery --help +``` + +To generate the mocks +``` +go generate ./... +``` diff --git a/internal/transactionsubmission/engine/ledger_number_tracker.go b/internal/transactionsubmission/engine/ledger_number_tracker.go new file mode 100644 index 000000000..4ebec2c13 --- /dev/null +++ b/internal/transactionsubmission/engine/ledger_number_tracker.go @@ -0,0 +1,83 @@ +package engine + +import ( + "fmt" + "sync" + "time" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" +) + +const ( + MaxLedgerAge = 10 * time.Second + IncrementForMaxLedgerBounds = 10 +) + +// LedgerNumberTracker is a helper struct that keeps track of the current ledger number. +// +//go:generate mockery --name=LedgerNumberTracker --case=underscore --structname=MockLedgerNumberTracker +type LedgerNumberTracker interface { + GetLedgerNumber() (int, error) + GetLedgerBounds() (*txnbuild.LedgerBounds, error) +} + +type DefaultLedgerNumberTracker struct { + maxLedgerAge time.Duration + hClient horizonclient.ClientInterface + ledgerNumber int + lastUpdatedAt time.Time + // mutex is used to make sure only one call to getLedgerNumberFromHorizon() is running at a time and to prevent running it too often. + mutex sync.Mutex +} + +func NewLedgerNumberTracker(hClient horizonclient.ClientInterface) (*DefaultLedgerNumberTracker, error) { + if hClient == nil { + return nil, fmt.Errorf("horizon client cannot be nil") + } + + return &DefaultLedgerNumberTracker{ + hClient: hClient, + maxLedgerAge: MaxLedgerAge, + }, nil +} + +func (se *DefaultLedgerNumberTracker) GetLedgerNumber() (int, error) { + se.mutex.Lock() + defer se.mutex.Unlock() + + if time.Since(se.lastUpdatedAt) > se.maxLedgerAge { + ledgerNumber, err := se.getLedgerNumberFromHorizon() + if err != nil { + return 0, fmt.Errorf("getting ledger number from horizon: %w", err) + } else { + se.ledgerNumber = ledgerNumber + se.lastUpdatedAt = time.Now() + } + } + + return se.ledgerNumber, nil +} + +func (se *DefaultLedgerNumberTracker) getLedgerNumberFromHorizon() (int, error) { + ledger, err := se.hClient.Root() + if err != nil { + return 0, utils.NewHorizonErrorWrapper(err) + } + + return int(ledger.HorizonSequence), nil +} + +func (se *DefaultLedgerNumberTracker) GetLedgerBounds() (*txnbuild.LedgerBounds, error) { + ledgerNumber, err := se.GetLedgerNumber() + if err != nil { + return nil, fmt.Errorf("getting ledger number: %w", err) + } + + return &txnbuild.LedgerBounds{ + MaxLedger: uint32(ledgerNumber + IncrementForMaxLedgerBounds), + }, nil +} + +var _ LedgerNumberTracker = (*DefaultLedgerNumberTracker)(nil) diff --git a/internal/transactionsubmission/engine/ledger_number_tracker_test.go b/internal/transactionsubmission/engine/ledger_number_tracker_test.go new file mode 100644 index 000000000..a4ff42938 --- /dev/null +++ b/internal/transactionsubmission/engine/ledger_number_tracker_test.go @@ -0,0 +1,269 @@ +package engine + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/support/render/problem" + "github.com/stellar/go/txnbuild" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewLedgerNumberTracker(t *testing.T) { + mockHorizonClient := &horizonclient.MockClient{} + + testCases := []struct { + name string + hClient horizonclient.ClientInterface + wantErrContains string + wantResult LedgerNumberTracker + }{ + { + name: "returns an error if the horizon client is nil", + hClient: nil, + wantErrContains: "horizon client cannot be nil", + }, + { + name: "πŸŽ‰ successfully provides new LedgerNumberTracker", + hClient: mockHorizonClient, + wantResult: &DefaultLedgerNumberTracker{ + hClient: mockHorizonClient, + maxLedgerAge: MaxLedgerAge, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ledgerNumberTracker, err := NewLedgerNumberTracker(tc.hClient) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, ledgerNumberTracker) + } else { + require.NoError(t, err) + assert.NotNil(t, ledgerNumberTracker) + assert.Equal(t, tc.wantResult, ledgerNumberTracker) + } + }) + } + + mockHorizonClient.AssertExpectations(t) +} + +func Test_LedgerNumberTracker_getLedgerNumberFromHorizon(t *testing.T) { + testCases := []struct { + name string + horizonResponseError error + wantErrContains string + horizonResponseRoot horizon.Root + horizonResponseLedgerNumber int + wantResult int + }{ + { + name: "returns an error if horizon returns a horizon error", + horizonResponseError: horizonclient.Error{ + Problem: problem.P{ + Title: "Foo", + Type: "bar", + Status: http.StatusTooManyRequests, + }, + }, + wantErrContains: "horizon response error: StatusCode=429, Type=bar, Title=Foo", + }, + { + name: "returns an error if horizon returns an unexpected error", + horizonResponseError: fmt.Errorf("some random error"), + wantErrContains: "horizon response error: some random error", + }, + { + name: "πŸŽ‰ successfully gets the latest ledger number", + horizonResponseRoot: horizon.Root{HorizonSequence: 1234}, + wantResult: 1234, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockHorizonClient := &horizonclient.MockClient{} + mockHorizonClient.On("Root").Return(tc.horizonResponseRoot, tc.horizonResponseError).Once() + + ledgerNumberTracker, err := NewLedgerNumberTracker(mockHorizonClient) + require.NoError(t, err) + + ledgerNumber, err := ledgerNumberTracker.getLedgerNumberFromHorizon() + if tc.horizonResponseError != nil { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Equal(t, 0, ledgerNumber) + } else { + require.NoError(t, err) + assert.Equal(t, tc.wantResult, ledgerNumber) + } + + mockHorizonClient.AssertExpectations(t) + }) + } +} + +func Test_LedgerNumberTracker_GetLedgerNumber(t *testing.T) { + const startingLedgerNumber = 1230 + + testCases := []struct { + name string + startingLedgerNumber int + isNumberExpired bool + horizonResponseRoot horizon.Root + horizonResponseError error + wantErrContains string + horizonResponseLedgerNumber int + wantResult int + }{ + { + name: "returns an error if horizon returns a horizon error (EXPIRED)", + startingLedgerNumber: startingLedgerNumber, + isNumberExpired: true, + horizonResponseError: horizonclient.Error{ + Problem: problem.P{ + Title: "Foo", + Type: "bar", + Status: http.StatusTooManyRequests, + }, + }, + wantErrContains: "getting ledger number from horizon: horizon response error: StatusCode=429, Type=bar, Title=Foo", + }, + { + name: "returns an error if horizon returns an unexpected error (EXPIRED)", + startingLedgerNumber: startingLedgerNumber, + isNumberExpired: true, + horizonResponseError: fmt.Errorf("some random error"), + wantErrContains: "getting ledger number from horizon: horizon response error: some random error", + }, + { + name: "πŸŽ‰ successfully gets the latest ledger number (EXPIRED)", + startingLedgerNumber: startingLedgerNumber, + isNumberExpired: true, + horizonResponseRoot: horizon.Root{HorizonSequence: 1234}, + wantResult: 1234, + }, + { + name: "πŸŽ‰ successfully gets the latest ledger number (NOT EXPIRED)", + startingLedgerNumber: startingLedgerNumber, + isNumberExpired: false, + wantResult: startingLedgerNumber, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockHorizonClient := &horizonclient.MockClient{} + if tc.isNumberExpired { + mockHorizonClient.On("Root").Return(tc.horizonResponseRoot, tc.horizonResponseError).Once() + } + + ledgerNumberTracker, err := NewLedgerNumberTracker(mockHorizonClient) + require.NoError(t, err) + ledgerNumberTracker.ledgerNumber = tc.startingLedgerNumber + if tc.isNumberExpired { + ledgerNumberTracker.lastUpdatedAt = time.Now().Add(-ledgerNumberTracker.maxLedgerAge - time.Second) + } else { + ledgerNumberTracker.lastUpdatedAt = time.Now().Add(-ledgerNumberTracker.maxLedgerAge + time.Second) + } + initialLedgerLastUpdatedAt := ledgerNumberTracker.lastUpdatedAt + + ledgerNumber, err := ledgerNumberTracker.GetLedgerNumber() + if tc.horizonResponseError != nil { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Equal(t, 0, ledgerNumber) + assert.Equal(t, initialLedgerLastUpdatedAt, ledgerNumberTracker.lastUpdatedAt) + } else { + require.NoError(t, err) + assert.Equal(t, tc.wantResult, ledgerNumber) + if tc.isNumberExpired { + assert.NotEqual(t, initialLedgerLastUpdatedAt, ledgerNumberTracker.lastUpdatedAt) + } else { + assert.Equal(t, initialLedgerLastUpdatedAt, ledgerNumberTracker.lastUpdatedAt) + } + } + + mockHorizonClient.AssertExpectations(t) + }) + } +} + +func Test_LedgerNumberTracker_GetLedgerBounds(t *testing.T) { + const startingLedgerNumber = 1230 + + testCases := []struct { + name string + startingLedgerNumber int + isNumberExpired bool + horizonResponseRoot horizon.Root + horizonResponseError error + wantErrContains string + wantResult *txnbuild.LedgerBounds + }{ + { + name: "returns an error if horizon returns a horizon error (EXPIRED)", + startingLedgerNumber: startingLedgerNumber, + isNumberExpired: true, + horizonResponseError: horizonclient.Error{ + Problem: problem.P{ + Title: "Foo", + Type: "bar", + Status: http.StatusTooManyRequests, + }, + }, + wantErrContains: "getting ledger number: getting ledger number from horizon: horizon response error: StatusCode=429, Type=bar, Title=Foo", + }, + { + name: "πŸŽ‰ successfully gets the latest ledger number (EXPIRED)", + startingLedgerNumber: startingLedgerNumber, + isNumberExpired: true, + horizonResponseRoot: horizon.Root{HorizonSequence: 1234}, + wantResult: &txnbuild.LedgerBounds{MaxLedger: 1234 + IncrementForMaxLedgerBounds}, + }, + { + name: "πŸŽ‰ successfully gets the latest ledger number (NOT EXPIRED)", + startingLedgerNumber: startingLedgerNumber, + isNumberExpired: false, + wantResult: &txnbuild.LedgerBounds{MaxLedger: startingLedgerNumber + IncrementForMaxLedgerBounds}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockHorizonClient := &horizonclient.MockClient{} + if tc.isNumberExpired { + mockHorizonClient.On("Root").Return(tc.horizonResponseRoot, tc.horizonResponseError).Once() + } + + ledgerNumberTracker, err := NewLedgerNumberTracker(mockHorizonClient) + require.NoError(t, err) + ledgerNumberTracker.ledgerNumber = tc.startingLedgerNumber + if tc.isNumberExpired { + ledgerNumberTracker.lastUpdatedAt = time.Now().Add(-ledgerNumberTracker.maxLedgerAge - time.Second) + } else { + ledgerNumberTracker.lastUpdatedAt = time.Now().Add(-ledgerNumberTracker.maxLedgerAge + time.Second) + } + + ledgerBounds, err := ledgerNumberTracker.GetLedgerBounds() + if tc.horizonResponseError != nil { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, ledgerBounds) + } else { + require.NoError(t, err) + assert.Equal(t, tc.wantResult, ledgerBounds) + } + + mockHorizonClient.AssertExpectations(t) + }) + } +} diff --git a/internal/transactionsubmission/engine/mocks/ledger_number_tracker.go b/internal/transactionsubmission/engine/mocks/ledger_number_tracker.go new file mode 100644 index 000000000..5eca14209 --- /dev/null +++ b/internal/transactionsubmission/engine/mocks/ledger_number_tracker.go @@ -0,0 +1,78 @@ +// Code generated by mockery v2.27.1. DO NOT EDIT. + +package mocks + +import ( + txnbuild "github.com/stellar/go/txnbuild" + mock "github.com/stretchr/testify/mock" +) + +// MockLedgerNumberTracker is an autogenerated mock type for the LedgerNumberTracker type +type MockLedgerNumberTracker struct { + mock.Mock +} + +// GetLedgerBounds provides a mock function with given fields: +func (_m *MockLedgerNumberTracker) GetLedgerBounds() (*txnbuild.LedgerBounds, error) { + ret := _m.Called() + + var r0 *txnbuild.LedgerBounds + var r1 error + if rf, ok := ret.Get(0).(func() (*txnbuild.LedgerBounds, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *txnbuild.LedgerBounds); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*txnbuild.LedgerBounds) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetLedgerNumber provides a mock function with given fields: +func (_m *MockLedgerNumberTracker) GetLedgerNumber() (int, error) { + ret := _m.Called() + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func() (int, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewMockLedgerNumberTracker interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockLedgerNumberTracker creates a new instance of MockLedgerNumberTracker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockLedgerNumberTracker(t mockConstructorTestingTNewMockLedgerNumberTracker) *MockLedgerNumberTracker { + mock := &MockLedgerNumberTracker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/transactionsubmission/engine/mocks/signature_service.go b/internal/transactionsubmission/engine/mocks/signature_service.go new file mode 100644 index 000000000..167c88d75 --- /dev/null +++ b/internal/transactionsubmission/engine/mocks/signature_service.go @@ -0,0 +1,154 @@ +// Code generated by mockery v2.23.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + keypair "github.com/stellar/go/keypair" + + mock "github.com/stretchr/testify/mock" + + txnbuild "github.com/stellar/go/txnbuild" +) + +// MockSignatureService is an autogenerated mock type for the SignatureService type +type MockSignatureService struct { + mock.Mock +} + +// BatchInsert provides a mock function with given fields: ctx, kps, shouldEncryptSeed, currLedgerNumber +func (_m *MockSignatureService) BatchInsert(ctx context.Context, kps []*keypair.Full, shouldEncryptSeed bool, currLedgerNumber int) error { + ret := _m.Called(ctx, kps, shouldEncryptSeed, currLedgerNumber) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []*keypair.Full, bool, int) error); ok { + r0 = rf(ctx, kps, shouldEncryptSeed, currLedgerNumber) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Delete provides a mock function with given fields: ctx, publicKey, currLedgerNumber +func (_m *MockSignatureService) Delete(ctx context.Context, publicKey string, currLedgerNumber int) error { + ret := _m.Called(ctx, publicKey, currLedgerNumber) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, int) error); ok { + r0 = rf(ctx, publicKey, currLedgerNumber) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DistributionAccount provides a mock function with given fields: +func (_m *MockSignatureService) DistributionAccount() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// NetworkPassphrase provides a mock function with given fields: +func (_m *MockSignatureService) NetworkPassphrase() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// SignFeeBumpStellarTransaction provides a mock function with given fields: ctx, feeBumpStellarTx, stellarAccounts +func (_m *MockSignatureService) SignFeeBumpStellarTransaction(ctx context.Context, feeBumpStellarTx *txnbuild.FeeBumpTransaction, stellarAccounts ...string) (*txnbuild.FeeBumpTransaction, error) { + _va := make([]interface{}, len(stellarAccounts)) + for _i := range stellarAccounts { + _va[_i] = stellarAccounts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, feeBumpStellarTx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *txnbuild.FeeBumpTransaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *txnbuild.FeeBumpTransaction, ...string) (*txnbuild.FeeBumpTransaction, error)); ok { + return rf(ctx, feeBumpStellarTx, stellarAccounts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *txnbuild.FeeBumpTransaction, ...string) *txnbuild.FeeBumpTransaction); ok { + r0 = rf(ctx, feeBumpStellarTx, stellarAccounts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*txnbuild.FeeBumpTransaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *txnbuild.FeeBumpTransaction, ...string) error); ok { + r1 = rf(ctx, feeBumpStellarTx, stellarAccounts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SignStellarTransaction provides a mock function with given fields: ctx, stellarTx, stellarAccounts +func (_m *MockSignatureService) SignStellarTransaction(ctx context.Context, stellarTx *txnbuild.Transaction, stellarAccounts ...string) (*txnbuild.Transaction, error) { + _va := make([]interface{}, len(stellarAccounts)) + for _i := range stellarAccounts { + _va[_i] = stellarAccounts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, stellarTx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *txnbuild.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *txnbuild.Transaction, ...string) (*txnbuild.Transaction, error)); ok { + return rf(ctx, stellarTx, stellarAccounts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *txnbuild.Transaction, ...string) *txnbuild.Transaction); ok { + r0 = rf(ctx, stellarTx, stellarAccounts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*txnbuild.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *txnbuild.Transaction, ...string) error); ok { + r1 = rf(ctx, stellarTx, stellarAccounts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewMockSignatureService creates a new instance of MockSignatureService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSignatureService(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSignatureService { + mock := &MockSignatureService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/transactionsubmission/engine/signature_service.go b/internal/transactionsubmission/engine/signature_service.go new file mode 100644 index 000000000..86ad12849 --- /dev/null +++ b/internal/transactionsubmission/engine/signature_service.go @@ -0,0 +1,203 @@ +package engine + +import ( + "context" + "fmt" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/strkey" + "github.com/stellar/go/txnbuild" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" +) + +//go:generate mockery --name=SignatureService --case=underscore --structname=MockSignatureService +type SignatureService interface { + DistributionAccount() string + NetworkPassphrase() string + SignStellarTransaction(ctx context.Context, stellarTx *txnbuild.Transaction, stellarAccounts ...string) (signedStellarTx *txnbuild.Transaction, err error) + SignFeeBumpStellarTransaction(ctx context.Context, feeBumpStellarTx *txnbuild.FeeBumpTransaction, stellarAccounts ...string) (signedFeeBumpStellarTx *txnbuild.FeeBumpTransaction, err error) + BatchInsert(ctx context.Context, kps []*keypair.Full, shouldEncryptSeed bool, currLedgerNumber int) (err error) + Delete(ctx context.Context, publicKey string, currLedgerNumber int) error +} + +type DefaultSignatureService struct { + networkPassphrase string + distributionAccount string + distributionKP *keypair.Full + dbConnectionPool db.DBConnectionPool + chAccModel store.ChannelAccountStore + encrypter utils.PrivateKeyEncrypter + encrypterPass string +} + +// NewDefaultSignatureService returns a new DefaultSignatureService instance. +func NewDefaultSignatureService(networkPassphrase string, dbConnectionPool db.DBConnectionPool, distributionSeed string, chAccStore store.ChannelAccountStore, encrypter utils.PrivateKeyEncrypter, encrypterPass string) (*DefaultSignatureService, error) { + if dbConnectionPool == nil { + return nil, fmt.Errorf("db connection pool cannot be nil") + } + if chAccStore == nil { + return nil, fmt.Errorf("channel account store cannot be nil") + } + + if (networkPassphrase != network.TestNetworkPassphrase) && (networkPassphrase != network.PublicNetworkPassphrase) { + return nil, fmt.Errorf("invalid network passphrase: %q", networkPassphrase) + } + + distributionKP, err := keypair.ParseFull(distributionSeed) + if err != nil { + return nil, fmt.Errorf("parsing distribution seed: %w", err) + } + + if encrypter == nil { + return nil, fmt.Errorf("private key encrypter cannot be nil") + } + + if encrypterPass == "" { + return nil, fmt.Errorf("private key encrypter passphrase cannot be empty") + } + + return &DefaultSignatureService{ + networkPassphrase: networkPassphrase, + distributionAccount: distributionKP.Address(), + distributionKP: distributionKP, + dbConnectionPool: dbConnectionPool, + chAccModel: chAccStore, + encrypter: encrypter, + encrypterPass: encrypterPass, + }, nil +} + +func (ds *DefaultSignatureService) DistributionAccount() string { + return ds.distributionAccount +} + +func (ds *DefaultSignatureService) NetworkPassphrase() string { + return ds.networkPassphrase +} + +func (ds *DefaultSignatureService) getKPsForAccounts(ctx context.Context, stellarAccounts ...string) ([]*keypair.Full, error) { + if len(stellarAccounts) == 0 { + return nil, fmt.Errorf("no accounts provided") + } + + accountsAlreadyAccountedFor := map[string]struct{}{} + kps := []*keypair.Full{} + for i, account := range stellarAccounts { + if _, ok := accountsAlreadyAccountedFor[account]; ok { + continue + } + accountsAlreadyAccountedFor[account] = struct{}{} + + if account == "" { + return nil, fmt.Errorf("account %d is empty", i) + } + + if account == ds.DistributionAccount() { + kps = append(kps, ds.distributionKP) + continue + } + + // Can return ErrRecordNotFound + chAcc, err := ds.chAccModel.Get(ctx, ds.dbConnectionPool, account, 0) + if err != nil { + return nil, fmt.Errorf("getting secret for channel account %q: %w", account, err) + } + + chAccPrivateKey := chAcc.PrivateKey + if !strkey.IsValidEd25519SecretSeed(chAccPrivateKey) { + chAccPrivateKey, err = ds.encrypter.Decrypt(chAccPrivateKey, ds.encrypterPass) + if err != nil { + return nil, fmt.Errorf("cannot decrypt private key: %w", err) + } + } + + kp, err := keypair.ParseFull(chAccPrivateKey) + if err != nil { + return nil, fmt.Errorf("parsing secret for channel account %q: %w", account, err) + } + kps = append(kps, kp) + } + + return kps, nil +} + +func (ds *DefaultSignatureService) SignStellarTransaction(ctx context.Context, stellarTx *txnbuild.Transaction, stellarAccounts ...string) (signedStellarTx *txnbuild.Transaction, err error) { + if stellarTx == nil { + return nil, fmt.Errorf("stellarTx cannot be nil") + } + + kps, err := ds.getKPsForAccounts(ctx, stellarAccounts...) + if err != nil { + return nil, fmt.Errorf("getting keypairs for accounts %v: %w", stellarAccounts, err) + } + + signedStellarTx, err = stellarTx.Sign(ds.NetworkPassphrase(), kps...) + if err != nil { + return nil, fmt.Errorf("signing transaction: %w", err) + } + + return signedStellarTx, nil +} + +func (ds *DefaultSignatureService) SignFeeBumpStellarTransaction(ctx context.Context, feeBumpStellarTx *txnbuild.FeeBumpTransaction, stellarAccounts ...string) (signedFeeBumpStellarTx *txnbuild.FeeBumpTransaction, err error) { + if feeBumpStellarTx == nil { + return nil, fmt.Errorf("stellarTx cannot be nil") + } + + kps, err := ds.getKPsForAccounts(ctx, stellarAccounts...) + if err != nil { + return nil, fmt.Errorf("getting keypairs for accounts %v: %w", stellarAccounts, err) + } + + signedFeeBumpStellarTx, err = feeBumpStellarTx.Sign(ds.NetworkPassphrase(), kps...) + if err != nil { + return nil, fmt.Errorf("signing transaction: %w", err) + } + + return signedFeeBumpStellarTx, nil +} + +func (ds *DefaultSignatureService) BatchInsert(ctx context.Context, kps []*keypair.Full, shouldEncryptSeed bool, currLedgerNumber int) (err error) { + if len(kps) == 0 { + return fmt.Errorf("no keypairs provided") + } + + batchInsertPayload := []*store.ChannelAccount{} + for _, kp := range kps { + publicKey := kp.Address() + privateKey := kp.Seed() + if shouldEncryptSeed { + privateKey, err = ds.encrypter.Encrypt(privateKey, ds.encrypterPass) + if err != nil { + return fmt.Errorf("encrypting channel account private key: %w", err) + } + } + + batchInsertPayload = append(batchInsertPayload, &store.ChannelAccount{ + PublicKey: publicKey, + PrivateKey: privateKey, + }) + } + + err = ds.chAccModel.BatchInsertAndLock(ctx, batchInsertPayload, currLedgerNumber, currLedgerNumber+IncrementForMaxLedgerBounds) + if err != nil { + return fmt.Errorf("batch inserting channel accounts: %w", err) + } + + return nil +} + +func (ds *DefaultSignatureService) Delete(ctx context.Context, publicKey string, lockedToLedgerNumber int) error { + err := ds.chAccModel.DeleteIfLockedUntil(ctx, publicKey, lockedToLedgerNumber) + if err != nil { + return fmt.Errorf("deleting channel account %q from database: %w", publicKey, err) + } + + return nil +} + +var _ SignatureService = &DefaultSignatureService{} diff --git a/internal/transactionsubmission/engine/signature_service_test.go b/internal/transactionsubmission/engine/signature_service_test.go new file mode 100644 index 000000000..7ddcfe337 --- /dev/null +++ b/internal/transactionsubmission/engine/signature_service_test.go @@ -0,0 +1,582 @@ +package engine + +import ( + "context" + "math" + "testing" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewDefaultSignatureService(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + chAccountStore := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + testCases := []struct { + name string + chAccountStore store.ChannelAccountStore + networkPassphrase string + distributionSeed string + encrypter utils.PrivateKeyEncrypter + encrypterPass string + wantErrContains string + }{ + { + name: "return an error if dbConnectionPool is nil", + wantErrContains: "channel account store cannot be nil", + }, + { + name: "return an error if networkPassphrase is invalid", + chAccountStore: chAccountStore, + networkPassphrase: "foo bar", + wantErrContains: `invalid network passphrase: "foo bar"`, + }, + { + name: "return an error if distributionSeed is invalid", + chAccountStore: chAccountStore, + networkPassphrase: network.TestNetworkPassphrase, + distributionSeed: "foo bar", + wantErrContains: "parsing distribution seed: base32 decode failed: illegal base32 data at input byte 7", + }, + { + name: "return an error if encrypter is nil", + chAccountStore: chAccountStore, + networkPassphrase: network.TestNetworkPassphrase, + distributionSeed: "SCPGNK3MRMXKNWGZ4ET3JZ6RUJIN7FMHT4ASVXDG7YPBL4WKBQNEL63F", + wantErrContains: "private key encrypter cannot be nil", + }, + { + name: "return an error if encrypterPass is empty", + chAccountStore: chAccountStore, + networkPassphrase: network.TestNetworkPassphrase, + distributionSeed: "SCPGNK3MRMXKNWGZ4ET3JZ6RUJIN7FMHT4ASVXDG7YPBL4WKBQNEL63F", + encrypter: &utils.PrivateKeyEncrypterMock{}, + wantErrContains: "private key encrypter passphrase cannot be empty", + }, + { + name: "πŸŽ‰ Successfully instantiates a new default signature service", + chAccountStore: chAccountStore, + networkPassphrase: network.TestNetworkPassphrase, + encrypter: &utils.PrivateKeyEncrypterMock{}, + encrypterPass: "SCPGNK3MRMXKNWGZ4ET3JZ6RUJIN7FMHT4ASVXDG7YPBL4WKBQNEL63F", + distributionSeed: "SCPGNK3MRMXKNWGZ4ET3JZ6RUJIN7FMHT4ASVXDG7YPBL4WKBQNEL63F", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sigService, err := NewDefaultSignatureService(tc.networkPassphrase, dbConnectionPool, tc.distributionSeed, tc.chAccountStore, tc.encrypter, tc.encrypterPass) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, sigService) + } else { + require.NoError(t, err) + assert.NotNil(t, sigService) + } + }) + } +} + +func Test_DefaultSignatureService_DistributionAccount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + chAccountStore := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + // test with the first KP: + distributionKP, err := keypair.Random() + require.NoError(t, err) + defaultSigService, err := NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccountStore, &utils.PrivateKeyEncrypterMock{}, distributionKP.Seed()) + require.NoError(t, err) + require.Equal(t, distributionKP.Address(), defaultSigService.DistributionAccount()) + + // test with the second KP, to make sure it's changing accordingly: + distributionKP, err = keypair.Random() + require.NoError(t, err) + defaultSigService, err = NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccountStore, &utils.PrivateKeyEncrypterMock{}, distributionKP.Seed()) + require.NoError(t, err) + require.Equal(t, distributionKP.Address(), defaultSigService.DistributionAccount()) +} + +func Test_DefaultSignatureService_NetworkPassphrase(t *testing.T) { + // test with testnet passphrase + sigService := &DefaultSignatureService{networkPassphrase: network.TestNetworkPassphrase} + assert.Equal(t, network.TestNetworkPassphrase, sigService.NetworkPassphrase()) + + // test with public network passphrase, to make sure it's changing accordingly + sigService = &DefaultSignatureService{networkPassphrase: network.PublicNetworkPassphrase} + assert.Equal(t, network.PublicNetworkPassphrase, sigService.NetworkPassphrase()) +} + +func Test_DefaultSignatureService_getKPsForAccounts(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccountStore := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + // create distribution account + distributionKP, err := keypair.Random() + require.NoError(t, err) + + // create default encrypter + encrypter := &utils.DefaultPrivateKeyEncrypter{} + encrypterPass := distributionKP.Seed() + + // create channel accounts in the DB + channelAccounts := store.CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 2) + chAccKP1, err := keypair.ParseFull(channelAccounts[0].PrivateKey) + require.NoError(t, err) + chAccKP2, err := keypair.ParseFull(channelAccounts[1].PrivateKey) + require.NoError(t, err) + + // create channel account that's not in the DB + nonExistentChannelAccountKP, err := keypair.Random() + require.NoError(t, err) + + // create channel account with encrypted private key + decryptableKeyChAccKP, err := keypair.Random() + require.NoError(t, err) + decryptableKeyChAccKPSeed, err := encrypter.Encrypt(decryptableKeyChAccKP.Seed(), encrypterPass) + require.NoError(t, err) + err = chAccountStore.Insert(ctx, chAccountStore.DBConnectionPool, decryptableKeyChAccKP.Address(), decryptableKeyChAccKPSeed) + require.NoError(t, err) + + // create Channel account with private key encrypted by a different passphrase + undecryptableKeyChAccKP, err := keypair.Random() + require.NoError(t, err) + undecryptableKeyChAccKPSeed, err := encrypter.Encrypt(undecryptableKeyChAccKP.Seed(), keypair.MustRandom().Seed()) + require.NoError(t, err) + err = chAccountStore.Insert(ctx, chAccountStore.DBConnectionPool, undecryptableKeyChAccKP.Address(), undecryptableKeyChAccKPSeed) + require.NoError(t, err) + + // create default signature service + defaultSigService, err := NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccountStore, encrypter, encrypterPass) + require.NoError(t, err) + + testCases := []struct { + name string + accounts []string + wantErrContains string + wantKeypairs []*keypair.Full + }{ + { + name: "return an error if no accounts are passed", + accounts: []string{}, + wantErrContains: "no accounts provided", + }, + { + name: "return an error if one of the accounts is empty", + accounts: []string{""}, + wantErrContains: "account 0 is empty", + }, + { + name: "return an error if one of the accounts doesn't exist in the database", + accounts: []string{nonExistentChannelAccountKP.Address()}, + wantErrContains: store.ErrRecordNotFound.Error(), + }, + { + name: "πŸŽ‰ Successfully returns the distribution KP", + accounts: []string{distributionKP.Address()}, + wantKeypairs: []*keypair.Full{distributionKP}, + }, + { + name: "πŸŽ‰ Successfully one result if there are repeated values in the input array", + accounts: []string{distributionKP.Address(), distributionKP.Address(), chAccKP1.Address(), chAccKP1.Address()}, + wantKeypairs: []*keypair.Full{distributionKP, chAccKP1}, + }, + { + name: "πŸŽ‰ Successfully returns distribution and channel accounts KPs, for unencrypted seeds", + accounts: []string{distributionKP.Address(), chAccKP1.Address(), chAccKP2.Address()}, + wantKeypairs: []*keypair.Full{distributionKP, chAccKP1, chAccKP2}, + }, + { + name: "πŸŽ‰ Successfully returns distribution and channel accounts KPs, with 1 encrypted seed", + accounts: []string{distributionKP.Address(), chAccKP1.Address(), chAccKP2.Address(), decryptableKeyChAccKP.Address()}, + wantKeypairs: []*keypair.Full{distributionKP, chAccKP1, chAccKP2, decryptableKeyChAccKP}, + }, + { + name: "return an error if one of the encrypted seeds cannot be decrypted with the expected passphrase", + accounts: []string{undecryptableKeyChAccKP.Address()}, + wantErrContains: "cannot decrypt private key: cipher: message authentication failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + kps, err := defaultSigService.getKPsForAccounts(ctx, tc.accounts...) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, kps) + } else { + require.NoError(t, err) + assert.Len(t, kps, len(tc.wantKeypairs)) + assert.Equal(t, tc.wantKeypairs, kps) + } + }) + } +} + +func Test_DefaultSignatureService_SignStellarTransaction(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + chAccountStore := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + ctx := context.Background() + + // create channel accounts in the DB + channelAccounts := store.CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1) + chAccKP, err := keypair.ParseFull(channelAccounts[0].PrivateKey) + require.NoError(t, err) + + // create distribution account + distributionKP, err := keypair.Random() + require.NoError(t, err) + + defaultSigService, err := NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccountStore, &utils.DefaultPrivateKeyEncrypter{}, distributionKP.Seed()) + require.NoError(t, err) + + // create stellar transaction + chSourceAccount := txnbuild.NewSimpleAccount(chAccKP.Address(), int64(9605939170639897)) + stellarTx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &chSourceAccount, + IncrementSequenceNum: true, + Operations: []txnbuild.Operation{&txnbuild.Payment{ + Destination: "GCCOBXW2XQNUSL467IEILE6MMCNRR66SSVL4YQADUNYYNUVREF3FIV2Z", + Amount: "10", + Asset: txnbuild.NativeAsset{}, + SourceAccount: distributionKP.Address(), + }}, + BaseFee: txnbuild.MinBaseFee, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(60)}, + }, + ) + require.NoError(t, err) + + wantSignedStellarTx, err := stellarTx.Sign(network.TestNetworkPassphrase, distributionKP, chAccKP) + require.NoError(t, err) + + testCases := []struct { + name string + stellarTx *txnbuild.Transaction + accounts []string + wantErrContains string + wantSignedStellarTx *txnbuild.Transaction + }{ + { + name: "return an error if stellar transaction is nil", + stellarTx: nil, + accounts: []string{}, + wantErrContains: "stellarTx cannot be nil", + }, + { + name: "return an error if no accounts are passed", + stellarTx: stellarTx, + accounts: []string{}, + wantErrContains: "no accounts provided", + }, + { + name: "πŸŽ‰ Successfully sign transaction when all incoming addresses are correct", + stellarTx: stellarTx, + accounts: []string{distributionKP.Address(), chAccKP.Address()}, + wantSignedStellarTx: wantSignedStellarTx, + }, + { + name: "πŸŽ‰ Successfully sign transaction when all some address are repeated", + stellarTx: stellarTx, + accounts: []string{distributionKP.Address(), chAccKP.Address(), chAccKP.Address()}, + wantSignedStellarTx: wantSignedStellarTx, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotSignedStellarTx, err := defaultSigService.SignStellarTransaction(ctx, tc.stellarTx, tc.accounts...) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, gotSignedStellarTx) + } else { + require.NoError(t, err) + assert.ElementsMatch(t, tc.wantSignedStellarTx.Signatures(), gotSignedStellarTx.Signatures()) + } + }) + } +} + +func Test_DefaultSignatureService_SignFeeBumpStellarTransaction(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + chAccountStore := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + ctx := context.Background() + + // create channel accounts in the DB + channelAccounts := store.CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1) + chAccKP, err := keypair.ParseFull(channelAccounts[0].PrivateKey) + require.NoError(t, err) + + // create distribution account + distributionKP, err := keypair.Random() + require.NoError(t, err) + + defaultSigService, err := NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccountStore, &utils.DefaultPrivateKeyEncrypter{}, distributionKP.Seed()) + require.NoError(t, err) + + // create stellar transaction + chSourceAccount := txnbuild.NewSimpleAccount(chAccKP.Address(), int64(9605939170639897)) + stellarTx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &chSourceAccount, + IncrementSequenceNum: true, + Operations: []txnbuild.Operation{&txnbuild.Payment{ + Destination: "GCCOBXW2XQNUSL467IEILE6MMCNRR66SSVL4YQADUNYYNUVREF3FIV2Z", + Amount: "10", + Asset: txnbuild.NativeAsset{}, + SourceAccount: distributionKP.Address(), + }}, + BaseFee: txnbuild.MinBaseFee, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(60)}, + }, + ) + require.NoError(t, err) + signedStellarTx, err := stellarTx.Sign(network.TestNetworkPassphrase, distributionKP, chAccKP) + require.NoError(t, err) + + feeBumpStellarTx, err := txnbuild.NewFeeBumpTransaction( + txnbuild.FeeBumpTransactionParams{ + Inner: signedStellarTx, + FeeAccount: distributionKP.Address(), + BaseFee: txnbuild.MinBaseFee, + }, + ) + require.NoError(t, err) + + wantSignedFeeBumpStellarTx, err := feeBumpStellarTx.Sign(network.TestNetworkPassphrase, distributionKP) + assert.NoError(t, err) + + testCases := []struct { + name string + feeBumpStellarTx *txnbuild.FeeBumpTransaction + accounts []string + wantErrContains string + wantSignedFeeBumpStellarTx *txnbuild.FeeBumpTransaction + }{ + { + name: "return an error if stellar transaction is nil", + feeBumpStellarTx: nil, + accounts: []string{}, + wantErrContains: "stellarTx cannot be nil", + }, + { + name: "return an error if no accounts are passed", + feeBumpStellarTx: feeBumpStellarTx, + accounts: []string{}, + wantErrContains: "no accounts provided", + }, + { + name: "πŸŽ‰ Successfully sign transaction when all incoming addresses are correct", + feeBumpStellarTx: feeBumpStellarTx, + accounts: []string{distributionKP.Address()}, + wantSignedFeeBumpStellarTx: wantSignedFeeBumpStellarTx, + }, + { + name: "πŸŽ‰ Successfully sign transaction when all some address are repeated", + feeBumpStellarTx: feeBumpStellarTx, + accounts: []string{distributionKP.Address(), distributionKP.Address()}, + wantSignedFeeBumpStellarTx: wantSignedFeeBumpStellarTx, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotSignedFeeBumpStellarTx, err := defaultSigService.SignFeeBumpStellarTransaction(ctx, tc.feeBumpStellarTx, tc.accounts...) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, gotSignedFeeBumpStellarTx) + } else { + require.NoError(t, err) + assert.ElementsMatch(t, tc.wantSignedFeeBumpStellarTx.Signatures(), gotSignedFeeBumpStellarTx.Signatures()) + } + }) + } +} + +func Test_DefaultSignatureService_BatchInsert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccountStore := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + distributionKP, err := keypair.Random() + require.NoError(t, err) + + signerKP1 := keypair.MustRandom() + signerKP2 := keypair.MustRandom() + + testCase := []struct { + name string + shouldEncryptSeed bool + kps []*keypair.Full + wantErrContains string + }{ + { + name: "if KPs is empty, return an error", + wantErrContains: "no keypairs provided", + }, + { + name: "πŸŽ‰ successfully bulk insert without encryption", + shouldEncryptSeed: false, + kps: []*keypair.Full{signerKP1, signerKP2}, + }, + { + name: "πŸŽ‰ successfully bulk insert with encryption", + shouldEncryptSeed: true, + kps: []*keypair.Full{signerKP1, signerKP2}, + }, + } + + type comparableChAccount struct { + PublicKey string + PrivateKey string + } + + defaultEncrypter := &utils.DefaultPrivateKeyEncrypter{} + encrypterPass := distributionKP.Seed() + defaultSigService, err := NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccountStore, defaultEncrypter, encrypterPass) + require.NoError(t, err) + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + count, err := chAccountStore.Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, count, "this test should have started with 0 channel accounts") + + err = defaultSigService.BatchInsert(ctx, tc.kps, tc.shouldEncryptSeed, 0) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } else { + require.NoError(t, err) + + allChAccounts, err := chAccountStore.GetAll(ctx, dbConnectionPool, math.MaxInt32, 0) + require.NoError(t, err) + assert.Equal(t, len(tc.kps), len(allChAccounts)) + + // compare the accounts + var allChAccountsComparable []comparableChAccount + for _, chAccount := range allChAccounts { + publicKey := chAccount.PublicKey + privateKey := chAccount.PrivateKey + + if tc.shouldEncryptSeed { + privateKey, err = defaultEncrypter.Decrypt(privateKey, encrypterPass) + require.NoError(t, err) + } + + allChAccountsComparable = append(allChAccountsComparable, comparableChAccount{ + PublicKey: publicKey, + PrivateKey: privateKey, + }) + } + + var tcChAccountsComparable []comparableChAccount + for _, kp := range tc.kps { + tcChAccountsComparable = append(tcChAccountsComparable, comparableChAccount{ + PublicKey: kp.Address(), + PrivateKey: kp.Seed(), + }) + } + assert.ElementsMatch(t, tcChAccountsComparable, allChAccountsComparable) + } + + store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } +} + +func Test_DefaultSignatureService_Delete(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccountStore := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + distributionKP, err := keypair.Random() + require.NoError(t, err) + defaultSigService, err := NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccountStore, &utils.PrivateKeyEncrypterMock{}, distributionKP.Seed()) + require.NoError(t, err) + + // at start: count=0 + count, err := chAccountStore.Count(ctx) + require.NoError(t, err) + assert.Equal(t, 0, count) + + // create 2 accounts: count=0->2 + channelAccounts := store.CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 2) + count, err = chAccountStore.Count(ctx) + require.NoError(t, err) + assert.Equal(t, 2, count) + + currLedgerNumber := 0 + lockUntilLedgerNumber := 10 + for _, chAcc := range channelAccounts { + _, err = chAccountStore.Lock(ctx, chAccountStore.DBConnectionPool, chAcc.PublicKey, int32(currLedgerNumber), int32(lockUntilLedgerNumber)) + require.NoError(t, err) + } + + // delete one account: count=2->1 + err = defaultSigService.Delete(ctx, channelAccounts[0].PublicKey, lockUntilLedgerNumber) + require.NoError(t, err) + count, err = chAccountStore.Count(ctx) + require.NoError(t, err) + assert.Equal(t, 1, count) + + // delete another account: count=1->0 + err = defaultSigService.Delete(ctx, channelAccounts[1].PublicKey, lockUntilLedgerNumber) + require.NoError(t, err) + count, err = chAccountStore.Count(ctx) + require.NoError(t, err) + assert.Equal(t, 0, count) + + // delete non-existing account: error expected + err = defaultSigService.Delete(ctx, "non-existent-account", 0) + require.Error(t, err) + assert.ErrorIs(t, err, store.ErrRecordNotFound) +} diff --git a/internal/transactionsubmission/engine/submitter_engine.go b/internal/transactionsubmission/engine/submitter_engine.go new file mode 100644 index 000000000..4177ef817 --- /dev/null +++ b/internal/transactionsubmission/engine/submitter_engine.go @@ -0,0 +1,26 @@ +package engine + +import ( + "fmt" + + "github.com/stellar/go/clients/horizonclient" +) + +// SubmitterEngine aggregates the dependencies that are shared between all Submitter instances, such as the Ledger +// number tracker. +type SubmitterEngine struct { + HorizonClient horizonclient.ClientInterface + LedgerNumberTracker +} + +func NewSubmitterEngine(hClient horizonclient.ClientInterface) (*SubmitterEngine, error) { + ledgerNumberTracker, err := NewLedgerNumberTracker(hClient) + if err != nil { + return nil, fmt.Errorf("creating ledger keeper: %w", err) + } + + return &SubmitterEngine{ + HorizonClient: hClient, + LedgerNumberTracker: ledgerNumberTracker, + }, nil +} diff --git a/internal/transactionsubmission/engine/submitter_engine_test.go b/internal/transactionsubmission/engine/submitter_engine_test.go new file mode 100644 index 000000000..b32f90c7f --- /dev/null +++ b/internal/transactionsubmission/engine/submitter_engine_test.go @@ -0,0 +1,51 @@ +package engine + +import ( + "testing" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewSubmitterEngine(t *testing.T) { + mockHorizonClient := &horizonclient.MockClient{} + + testCases := []struct { + name string + hClient horizonclient.ClientInterface + wantErrContains string + wantResult *SubmitterEngine + }{ + { + name: "returns an error if the horizon client is nil", + hClient: nil, + wantErrContains: "creating ledger keeper: horizon client cannot be nil", + }, + { + name: "πŸŽ‰ successfully provides new SubmitterEngine", + hClient: mockHorizonClient, + wantResult: &SubmitterEngine{ + HorizonClient: mockHorizonClient, + LedgerNumberTracker: &DefaultLedgerNumberTracker{hClient: mockHorizonClient, maxLedgerAge: MaxLedgerAge}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + submitterEngine, err := NewSubmitterEngine(tc.hClient) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, submitterEngine) + } else { + require.NoError(t, err) + assert.NotNil(t, submitterEngine) + assert.Equal(t, tc.wantResult, submitterEngine) + } + }) + } + + mockHorizonClient.AssertExpectations(t) +} diff --git a/internal/transactionsubmission/engine/tx_processing_limiter.go b/internal/transactionsubmission/engine/tx_processing_limiter.go new file mode 100644 index 000000000..fc2ae67f1 --- /dev/null +++ b/internal/transactionsubmission/engine/tx_processing_limiter.go @@ -0,0 +1,73 @@ +package engine + +import ( + "sync" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" +) + +const ( + defaultBundlesSelectionLimit = 8 + indeterminateResponsesToleranceLimit = 10 + minutesInWindow = 3 +) + +// TransactionProcessingLimiter is utilized by the manager and transaction worker to share metadata about and adjust +// the rate at which tss processes transactions based on responses from Horizon. +type TransactionProcessingLimiter struct { + CurrNumChannelAccounts int + IndeterminateResponsesCounter int + CounterLastUpdated time.Time + limitValue int + mutex sync.Mutex +} + +func NewTransactionProcessingLimiter(limit int) *TransactionProcessingLimiter { + if limit < 0 { + limit = defaultBundlesSelectionLimit + } + + return &TransactionProcessingLimiter{ + CurrNumChannelAccounts: limit, + IndeterminateResponsesCounter: 0, + CounterLastUpdated: time.Now(), + limitValue: limit, + } +} + +// AdjustLimitIfNeeded re-establishes the transaction processing limit based on how many transactions result in +// - `504`, 429`, `400` - tx_insufficient_fee` which are indicators for network congestion causing a cascade of further +// transaction failures and need for retries. +func (tpl *TransactionProcessingLimiter) AdjustLimitIfNeeded(hErr *utils.HorizonErrorWrapper) { + tpl.mutex.Lock() + defer tpl.mutex.Unlock() + + if !(hErr.IsRateLimit() || hErr.IsGatewayTimeout() || hErr.IsTxInsufficientFee()) { + return + } + + tpl.IndeterminateResponsesCounter++ + // We can tweek the following values as needed, and maybe add additional functionality to + // dynamically determine values for the default selection limit rather than using the default harcoded values + if tpl.IndeterminateResponsesCounter >= indeterminateResponsesToleranceLimit { + tpl.limitValue = defaultBundlesSelectionLimit + tpl.CounterLastUpdated = time.Now() + } +} + +// LimitValue resets the necessary counter-related values when the current time is well outside the fixed +// window of the last refresh, and serves as a getter for the `limitValue` field. +func (tpl *TransactionProcessingLimiter) LimitValue() int { + tpl.mutex.Lock() + defer tpl.mutex.Unlock() + // refresh counter on a fixed window basis + now := time.Now() + if now.After(tpl.CounterLastUpdated.Add(minutesInWindow * time.Minute)) { + tpl.IndeterminateResponsesCounter = 0 + tpl.CounterLastUpdated = now + tpl.limitValue = tpl.CurrNumChannelAccounts + } + + return tpl.limitValue +} diff --git a/internal/transactionsubmission/engine/tx_processing_limiter_test.go b/internal/transactionsubmission/engine/tx_processing_limiter_test.go new file mode 100644 index 000000000..3557bf22f --- /dev/null +++ b/internal/transactionsubmission/engine/tx_processing_limiter_test.go @@ -0,0 +1,145 @@ +package engine + +import ( + "net/http" + "testing" + "time" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/render/problem" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + "github.com/stretchr/testify/assert" +) + +func Test_TxProcessingLimiter_AdjustLimitIfNeeded(t *testing.T) { + currNumChannelAccounts := 50 + + testCases := []struct { + name string + hErr *utils.HorizonErrorWrapper + wantResult *TransactionProcessingLimiter + }{ + { + name: "adjusts limit if the horizon client error is too_many_requests", + hErr: utils.NewHorizonErrorWrapper( + &horizonclient.Error{ + Problem: problem.P{Status: http.StatusTooManyRequests}, + }, + ), + wantResult: &TransactionProcessingLimiter{ + limitValue: defaultBundlesSelectionLimit, + IndeterminateResponsesCounter: indeterminateResponsesToleranceLimit, + }, + }, + { + name: "adjusts limit if the horizon client error is gateway_timeout", + hErr: utils.NewHorizonErrorWrapper( + &horizonclient.Error{ + Problem: problem.P{Status: http.StatusGatewayTimeout}, + }, + ), + wantResult: &TransactionProcessingLimiter{ + limitValue: defaultBundlesSelectionLimit, + IndeterminateResponsesCounter: indeterminateResponsesToleranceLimit, + }, + }, + { + name: "adjusts limit if one of the operation error is tx_insufficient_fee", + hErr: utils.NewHorizonErrorWrapper( + &horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_insufficient_fee", + }, + }, + }, + }, + ), + wantResult: &TransactionProcessingLimiter{ + limitValue: defaultBundlesSelectionLimit, + IndeterminateResponsesCounter: indeterminateResponsesToleranceLimit, + }, + }, + { + name: "no adjustment for determinate error", + hErr: utils.NewHorizonErrorWrapper( + &horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_bad_auth", + }, + }, + }, + }, + ), + wantResult: &TransactionProcessingLimiter{ + limitValue: currNumChannelAccounts, + IndeterminateResponsesCounter: indeterminateResponsesToleranceLimit - 1, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + txProcessingLimiter := &TransactionProcessingLimiter{ + CurrNumChannelAccounts: currNumChannelAccounts, + limitValue: currNumChannelAccounts, + IndeterminateResponsesCounter: indeterminateResponsesToleranceLimit - 1, + CounterLastUpdated: time.Now(), + } + txProcessingLimiter.AdjustLimitIfNeeded(tc.hErr) + + assert.Equal(t, txProcessingLimiter.limitValue, tc.wantResult.limitValue) + assert.Equal(t, txProcessingLimiter.IndeterminateResponsesCounter, tc.wantResult.IndeterminateResponsesCounter) + }) + } +} + +func Test_TxProcessingLimiter_LimitValue(t *testing.T) { + initialLimitValue := 100 + currNumChannelAccounts := 50 + + testCases := []struct { + name string + wait func(tpl *TransactionProcessingLimiter) + wantResult *TransactionProcessingLimiter + }{ + { + name: "no change when the time is before current window is complete", + wait: func(tpl *TransactionProcessingLimiter) {}, + wantResult: &TransactionProcessingLimiter{ + limitValue: initialLimitValue, + IndeterminateResponsesCounter: indeterminateResponsesToleranceLimit - 1, + }, + }, + { + name: "change when the time is after current window is complete", + wait: func(tpl *TransactionProcessingLimiter) { + tpl.CounterLastUpdated = tpl.CounterLastUpdated.Add(-10 * time.Minute) + }, + wantResult: &TransactionProcessingLimiter{ + limitValue: currNumChannelAccounts, + IndeterminateResponsesCounter: 0, + }, + }, + } + + for _, tc := range testCases { + txProcessingLimiter := &TransactionProcessingLimiter{ + CurrNumChannelAccounts: currNumChannelAccounts, + limitValue: initialLimitValue, + IndeterminateResponsesCounter: indeterminateResponsesToleranceLimit - 1, + CounterLastUpdated: time.Now(), + } + tc.wait(txProcessingLimiter) + lv := txProcessingLimiter.LimitValue() + + assert.Equal(t, tc.wantResult.limitValue, txProcessingLimiter.limitValue) + assert.Equal(t, tc.wantResult.IndeterminateResponsesCounter, txProcessingLimiter.IndeterminateResponsesCounter) + assert.Equal(t, tc.wantResult.limitValue, lv) + } +} diff --git a/internal/transactionsubmission/horizon.go b/internal/transactionsubmission/horizon.go new file mode 100644 index 000000000..ab0615b00 --- /dev/null +++ b/internal/transactionsubmission/horizon.go @@ -0,0 +1,217 @@ +package transactionsubmission + +import ( + "context" + "errors" + "fmt" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/keypair" + "github.com/stellar/go/support/log" + "github.com/stellar/go/txnbuild" + "golang.org/x/exp/slices" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" +) + +var ErrInvalidNumOfChannelAccountsToCreate = errors.New("invalid number of channel accounts to create") + +// MaximumCreateAccountOperationsPerStellarTx is the max number of sponsored accounts we can create in one transaction +// due to the signature limit. +const MaximumCreateAccountOperationsPerStellarTx = 19 + +// MaxNumberOfChannelAccounts is the limit for the number of accounts tx submission service should manage. +const MaxNumberOfChannelAccounts = 1000 + +// MinNumberOfChannelAccounts is the minimum number of accounts tx submission service should manage. +const MinNumberOfChannelAccounts = 1 + +// DefaultRevokeSponsorshipReserveAmount is the amount of the native asset that the sponsoring account will send +// to the sponsored account to cover the reserve that is needed to for revoking account sponsorship. +// The amount will be send back to the sponsoring account once the sponsored account is deleted onchain. +const DefaultRevokeSponsorshipReserveAmount = "1.5" + +// CreateChannelAccountsOnChain will create up to 19 accounts per Transaction due to the 20 signatures per tx limit This +// is also a good opportunity to periodically write the generated accounts to persistent storage if generating large +// amounts of channel accounts. +func CreateChannelAccountsOnChain(ctx context.Context, horizonClient horizonclient.ClientInterface, numOfChanAccToCreate int, maxBaseFee int, shouldEncryptSeed bool, sigService engine.SignatureService, currLedgerNumber int) (newAccountAddresses []string, err error) { + defer func() { + // If we failed to create the accounts, we should delete the accounts that were added to the signature service. + if err != nil && sigService != nil { + cloneOfNewAccountAddresses := slices.Clone(newAccountAddresses) + for _, accountAddress := range cloneOfNewAccountAddresses { + if accountAddress == sigService.DistributionAccount() { + continue + } + deleteErr := sigService.Delete(ctx, accountAddress, currLedgerNumber+engine.IncrementForMaxLedgerBounds) + if deleteErr != nil { + log.Ctx(ctx).Errorf("failed to delete channel account %s: %v", accountAddress, deleteErr) + } + } + newAccountAddresses = nil + } + }() + + if numOfChanAccToCreate > MaximumCreateAccountOperationsPerStellarTx { + return nil, fmt.Errorf("cannot create more than %d channel accounts", MaximumCreateAccountOperationsPerStellarTx) + } + + if numOfChanAccToCreate <= 0 { + return nil, ErrInvalidNumOfChannelAccountsToCreate + } + + rootAccount, err := horizonClient.AccountDetail(horizonclient.AccountRequest{ + AccountID: sigService.DistributionAccount(), + }) + if err != nil { + return nil, fmt.Errorf("failed to retrieve root account: %w", err) + } + + var sponsoredCreateAccountOps []txnbuild.Operation + + kpsToCreate := []*keypair.Full{} + + // Prepare Stellar operations to create the sponsored channel accounts + for i := 0; i < numOfChanAccToCreate; i++ { + // generate random keypair for this channel account + var channelAccountKP *keypair.Full + channelAccountKP, err = keypair.Random() + if err != nil { + return nil, fmt.Errorf("failed to generate keypair: %w", err) + } + log.Ctx(ctx).Infof("creating sponsored stellar account with address: %s", channelAccountKP.Address()) + + sponsoredCreateAccountOps = append( + sponsoredCreateAccountOps, + + // add sponsor operations for this account + &txnbuild.BeginSponsoringFutureReserves{ + SponsoredID: channelAccountKP.Address(), + }, + &txnbuild.CreateAccount{ + Destination: channelAccountKP.Address(), + Amount: "0", + }, + &txnbuild.EndSponsoringFutureReserves{ + SourceAccount: channelAccountKP.Address(), + }, + ) + + // append this channel account to the list of signers + kpsToCreate = append(kpsToCreate, channelAccountKP) + newAccountAddresses = append(newAccountAddresses, channelAccountKP.Address()) + } + + err = sigService.BatchInsert(ctx, kpsToCreate, shouldEncryptSeed, currLedgerNumber) + if err != nil { + return nil, fmt.Errorf("failed to insert channel accounts into signature service: %w", err) + } + + // create a new transaction with the account creation/sponsorship operations + tx, err := txnbuild.NewTransaction(txnbuild.TransactionParams{ + SourceAccount: &rootAccount, + IncrementSequenceNum: true, + Operations: sponsoredCreateAccountOps, + BaseFee: int64(maxBaseFee), + Preconditions: txnbuild.Preconditions{ + TimeBounds: txnbuild.NewTimeout(15), + }, + }) + if err != nil { + return nil, fmt.Errorf("creating transaction for channel account creation: %w", err) + } + + // sign the transaction + signers := append([]string{sigService.DistributionAccount()}, newAccountAddresses...) + tx, err = sigService.SignStellarTransaction(ctx, tx, signers...) + if err != nil { + return newAccountAddresses, fmt.Errorf("signing transaction: %w", err) + } + + _, err = horizonClient.SubmitTransactionWithOptions(tx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}) + if hError := horizonclient.GetError(err); hError != nil { + hErrorStr := utils.GetHorizonErrorString(*hError) + return newAccountAddresses, fmt.Errorf("creating sponsored channel accounts: %v", hErrorStr) + } else if err != nil { + return newAccountAddresses, fmt.Errorf("creating sponsored channel accounts: %w", err) + } + log.Ctx(ctx).Infof("πŸŽ‰ Successfully created %d sponsored channel accounts", len(newAccountAddresses)) + + return newAccountAddresses, nil +} + +// DeleteChannelAccountOnChain creates, signs, and broadcasts a transaction to delete a channel account onchain. +func DeleteChannelAccountOnChain( + ctx context.Context, + horizonClient horizonclient.ClientInterface, + chAccAddress string, + maxBaseFee int64, + sigService engine.SignatureService, + lockedUntilLedgerNumber int, +) error { + distributionAccount := sigService.DistributionAccount() + rootAccount, err := horizonClient.AccountDetail(horizonclient.AccountRequest{ + AccountID: distributionAccount, + }) + if err != nil { + return fmt.Errorf("retrieving root account from distribution seed: %w", err) + } + + // TODO: Currently, this transaction deletes a single sponsored account onchain, we may want to + // attempt to delete more accounts per tx in the future up to the limit of operations and + // signatures a single tx will allow + tx, err := txnbuild.NewTransaction(txnbuild.TransactionParams{ + SourceAccount: &rootAccount, + IncrementSequenceNum: true, + Operations: []txnbuild.Operation{ + &txnbuild.Payment{ + SourceAccount: rootAccount.AccountID, + Destination: chAccAddress, + Amount: DefaultRevokeSponsorshipReserveAmount, + Asset: txnbuild.NativeAsset{}, + }, + &txnbuild.RevokeSponsorship{ + SponsorshipType: txnbuild.RevokeSponsorshipTypeAccount, + Account: &chAccAddress, + }, + &txnbuild.AccountMerge{ + Destination: distributionAccount, + SourceAccount: chAccAddress, + }, + }, + BaseFee: maxBaseFee, + Preconditions: txnbuild.Preconditions{ + TimeBounds: txnbuild.NewTimeout(15), + }, + }) + if err != nil { + return fmt.Errorf( + "constructing remove channel account transaction for account %s: %w", + chAccAddress, + err, + ) + } + + // the root account authorizes the sponsorship revocation, while the channel account authorizes + // merging into the distribution account + tx, err = sigService.SignStellarTransaction(ctx, tx, sigService.DistributionAccount(), chAccAddress) + if err != nil { + return fmt.Errorf("signing remove account transaction for account %s: %w", chAccAddress, err) + } + + _, err = horizonClient.SubmitTransactionWithOptions( + tx, + horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}, + ) + if err != nil { + return fmt.Errorf("submitting remove account transaction to the network for account %s: %w", chAccAddress, err) + } + + err = sigService.Delete(ctx, chAccAddress, lockedUntilLedgerNumber) + if err != nil { + return fmt.Errorf("deleting channel account %s from the store: %w", chAccAddress, err) + } + + return nil +} diff --git a/internal/transactionsubmission/horizon_test.go b/internal/transactionsubmission/horizon_test.go new file mode 100644 index 000000000..6cacc0d81 --- /dev/null +++ b/internal/transactionsubmission/horizon_test.go @@ -0,0 +1,346 @@ +package transactionsubmission + +import ( + "context" + "errors" + "fmt" + "math" + "net/http" + "testing" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/strkey" + "github.com/stellar/go/support/render/problem" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + engineMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/mocks" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_CreateChannelAccountsOnChain(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + horizonClientMock := &horizonclient.MockClient{} + privateKeyEncrypterMock := &utils.PrivateKeyEncrypterMock{} + currLedgerNumber := 100 + ctx := context.Background() + chAccModel := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + distributionKP := keypair.MustRandom() + encrypterPass := distributionKP.Seed() + sigService, err := engine.NewDefaultSignatureService(network.TestNetworkPassphrase, dbConnectionPool, distributionKP.Seed(), chAccModel, privateKeyEncrypterMock, encrypterPass) + require.NoError(t, err) + + testCases := []struct { + name string + numOfChanAccToCreate int + shouldEncryptSeed bool + prepareMocksFn func() + wantErrContains string + }{ + { + name: "returns error when 'numOfChanAccToCreate > MaximumCreateAccountOperationsPerStellarTx'", + numOfChanAccToCreate: MaximumCreateAccountOperationsPerStellarTx + 1, + wantErrContains: "cannot create more than 19 channel accounts", + }, + { + name: "returns error when numOfChanAccToCreate=0", + numOfChanAccToCreate: 0, + wantErrContains: ErrInvalidNumOfChannelAccountsToCreate.Error(), + }, + { + name: "returns error when numOfChanAccToCreate=-2", + numOfChanAccToCreate: -2, + wantErrContains: ErrInvalidNumOfChannelAccountsToCreate.Error(), + }, + { + name: "returns error when HorizonClient fails getting AccountDetails", + numOfChanAccToCreate: 2, + prepareMocksFn: func() { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: sigService.DistributionAccount()}). + Return(horizon.Account{}, horizonclient.Error{ + Problem: problem.NotFound, + }). + Once() + }, + wantErrContains: `failed to retrieve root account: horizon error: "Resource Missing" - check horizon.Error.Problem for more information`, + }, + { + name: "returns error when fails encrypting private key", + numOfChanAccToCreate: 2, + shouldEncryptSeed: true, + prepareMocksFn: func() { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: sigService.DistributionAccount()}). + Return(horizon.Account{ + AccountID: sigService.DistributionAccount(), + Sequence: 1, + }, nil). + Once() + privateKeyEncrypterMock. + On("Encrypt", mock.AnythingOfType("string"), encrypterPass). + Return("", errors.New("unexpected error")). + Once() + }, + wantErrContains: "encrypting channel account private key: unexpected error", + }, + { + name: "returns error when fails submitting transaction to horizon", + numOfChanAccToCreate: 2, + prepareMocksFn: func() { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: sigService.DistributionAccount()}). + Return(horizon.Account{ + AccountID: sigService.DistributionAccount(), + Sequence: 1, + }, nil). + Once(). + On("SubmitTransactionWithOptions", mock.AnythingOfType("*txnbuild.Transaction"), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, horizonclient.Error{ + Problem: problem.P{ + Type: "https://stellar.org/horizon-errors/timeout", + Title: "Timeout", + Detail: "Foo bar detail", + Status: http.StatusRequestTimeout, + Extras: map[string]interface{}{"foo": "bar"}, + }, + }). + Once() + }, + wantErrContains: "creating sponsored channel accounts: Type: https://stellar.org/horizon-errors/timeout, Title: Timeout, Status: 408, Detail: Foo bar detail, Extras: map[foo:bar]", + }, + { + name: "πŸŽ‰ successfully creates channel accounts on-chain (UNENCRYPTED)", + numOfChanAccToCreate: 2, + shouldEncryptSeed: false, + prepareMocksFn: func() { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: distributionKP.Address()}). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 1, + }, nil). + Once(). + On("SubmitTransactionWithOptions", mock.AnythingOfType("*txnbuild.Transaction"), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + }, + }, + { + name: "πŸŽ‰ successfully creates channel accounts on-chain (ENCRYPTED)", + numOfChanAccToCreate: 3, + shouldEncryptSeed: true, + prepareMocksFn: func() { + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: distributionKP.Address()}). + Return(horizon.Account{ + AccountID: distributionKP.Address(), + Sequence: 1, + }, nil). + Once(). + On("SubmitTransactionWithOptions", mock.AnythingOfType("*txnbuild.Transaction"), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + privateKeyEncrypterMock. + On("Encrypt", mock.AnythingOfType("string"), encrypterPass).Return("encryptedkey", nil).Times(3). + On("Decrypt", mock.AnythingOfType("string"), encrypterPass).Return(keypair.MustRandom().Seed(), nil).Times(3) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + count, err := chAccModel.Count(ctx) + require.NoError(t, err) + assert.Equal(t, 0, count) + + if tc.prepareMocksFn != nil { + tc.prepareMocksFn() + } + + channelAccountAddresses, err := CreateChannelAccountsOnChain(ctx, horizonClientMock, tc.numOfChanAccToCreate, txnbuild.MinBaseFee, tc.shouldEncryptSeed, sigService, currLedgerNumber) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.Empty(t, channelAccountAddresses) + assert.ErrorContains(t, err, tc.wantErrContains) + + count, err = chAccModel.Count(ctx) + require.NoError(t, err) + assert.Equal(t, 0, count) + } else { + require.NoError(t, err) + assert.Len(t, channelAccountAddresses, tc.numOfChanAccToCreate) + + count, err = chAccModel.Count(ctx) + require.NoError(t, err) + assert.Equal(t, tc.numOfChanAccToCreate, count) + + allChAcc, err := chAccModel.GetAll(ctx, dbConnectionPool, math.MaxInt32, 100) + require.NoError(t, err) + assert.Len(t, allChAcc, tc.numOfChanAccToCreate) + + if !tc.shouldEncryptSeed { + for _, chAcc := range allChAcc { + assert.True(t, strkey.IsValidEd25519SecretSeed(chAcc.PrivateKey)) + } + } else { + for _, chAcc := range allChAcc { + assert.False(t, strkey.IsValidEd25519SecretSeed(chAcc.PrivateKey)) + } + } + } + + store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } + + horizonClientMock.AssertExpectations(t) + privateKeyEncrypterMock.AssertExpectations(t) +} + +func Test_DeleteChannelAccountOnChain(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + horizonClientMock := &horizonclient.MockClient{} + privateKeyEncrypterMock := &utils.PrivateKeyEncrypterMock{} + ctx := context.Background() + + distributionKP := keypair.MustRandom() + distributionAddress := distributionKP.Address() + mockSigService := &engineMocks.MockSignatureService{} + require.NoError(t, err) + + chAccAddress := keypair.MustRandom().Address() + currLedger := 100 + + testCases := []struct { + name string + prepareMocksFn func() + chAccAddressToDelete string + wantErrContains string + }{ + { + name: "returns error when HorizonClient fails getting AccountDetails", + prepareMocksFn: func() { + mockSigService.On("DistributionAccount").Return(distributionAddress).Once() + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: distributionAddress}). + Return(horizon.Account{}, horizonclient.Error{ + Problem: problem.NotFound, + }). + Once() + }, + wantErrContains: `retrieving root account from distribution seed: horizon error: "Resource Missing" - check horizon.Error.Problem for more information`, + }, + { + name: "returns error when channel account doesnt exist", + chAccAddressToDelete: chAccAddress, + prepareMocksFn: func() { + mockSigService.On("DistributionAccount").Return(distributionAddress).Twice() + mockSigService. + On("SignStellarTransaction", ctx, mock.AnythingOfType("*txnbuild.Transaction"), distributionAddress, chAccAddress). + Return(nil, fmt.Errorf("signing remove account transaction for account")).Once() + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: distributionAddress}). + Return(horizon.Account{ + AccountID: distributionAddress, + Sequence: 1, + }, nil). + Once() + }, + wantErrContains: "signing remove account transaction for account", + }, + { + name: "returns error when fails submitting transaction to horizon", + chAccAddressToDelete: chAccAddress, + prepareMocksFn: func() { + mockSigService.On("DistributionAccount").Return(distributionAddress).Twice() + mockSigService. + On("SignStellarTransaction", ctx, mock.AnythingOfType("*txnbuild.Transaction"), distributionAddress, chAccAddress). + Return(&txnbuild.Transaction{}, nil).Once() + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: distributionAddress}). + Return(horizon.Account{ + AccountID: distributionAddress, + Sequence: 1, + }, nil). + Once() + horizonClientMock.On("SubmitTransactionWithOptions", mock.AnythingOfType("*txnbuild.Transaction"), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, horizonclient.Error{ + Problem: problem.P{ + Type: "https://stellar.org/horizon-errors/timeout", + Title: "Timeout", + Status: http.StatusRequestTimeout, + }, + }). + Once() + }, + wantErrContains: fmt.Sprintf( + `submitting remove account transaction to the network for account %s: horizon error: "Timeout" - check horizon.Error.Problem for more information`, + chAccAddress, + ), + }, + { + name: "πŸŽ‰ Successfully deletes channel account on chain and database", + chAccAddressToDelete: chAccAddress, + prepareMocksFn: func() { + mockSigService.On("DistributionAccount").Return(distributionAddress).Twice() + mockSigService. + On("SignStellarTransaction", ctx, mock.AnythingOfType("*txnbuild.Transaction"), distributionAddress, chAccAddress). + Return(&txnbuild.Transaction{}, nil).Once() + mockSigService.On("Delete", ctx, chAccAddress, currLedger).Return(nil).Once() + horizonClientMock. + On("AccountDetail", horizonclient.AccountRequest{AccountID: distributionAddress}). + Return(horizon.Account{ + AccountID: distributionAddress, + Sequence: 1, + }, nil). + Once() + horizonClientMock.On("SubmitTransactionWithOptions", mock.AnythingOfType("*txnbuild.Transaction"), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{}, nil). + Once() + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.prepareMocksFn != nil { + tc.prepareMocksFn() + } + + err = DeleteChannelAccountOnChain(ctx, horizonClientMock, tc.chAccAddressToDelete, txnbuild.MinBaseFee, mockSigService, currLedger) + + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } else { + require.NoError(t, err) + } + + store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } + + mockSigService.AssertExpectations(t) + horizonClientMock.AssertExpectations(t) + privateKeyEncrypterMock.AssertExpectations(t) +} diff --git a/internal/transactionsubmission/manager.go b/internal/transactionsubmission/manager.go new file mode 100644 index 000000000..bc44c131d --- /dev/null +++ b/internal/transactionsubmission/manager.go @@ -0,0 +1,278 @@ +package transactionsubmission + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/network" + "github.com/stellar/go/strkey" + "github.com/stellar/go/support/log" + "github.com/stellar/go/txnbuild" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" +) + +const serviceName = "Transaction Submission Service" + +type SubmitterOptions struct { + DatabaseDSN string + HorizonURL string + NetworkPassphrase string + DistributionSeed string + NumChannelAccounts int + QueuePollingInterval int + MaxBaseFee int + MonitorService monitor.MonitorServiceInterface + PrivateKeyEncrypter utils.PrivateKeyEncrypter + CrashTrackerClient crashtracker.CrashTrackerClient +} + +func (so *SubmitterOptions) validate() error { + if so.DatabaseDSN == "" { + return fmt.Errorf("database DSN cannot be empty") + } + + if so.MonitorService == nil { + return fmt.Errorf("monitor service cannot be nil") + } + + if so.HorizonURL == "" { + return fmt.Errorf("horizon url cannot be empty") + } + + if (so.NetworkPassphrase != network.TestNetworkPassphrase) && (so.NetworkPassphrase != network.PublicNetworkPassphrase) { + return fmt.Errorf("network passphrase %q is invalid", so.NetworkPassphrase) + } + + if so.PrivateKeyEncrypter == nil { + return fmt.Errorf("private key encrypter cannot be nil") + } + + if !strkey.IsValidEd25519SecretSeed(so.DistributionSeed) { + return fmt.Errorf("distribution seed is invalid") + } + + if so.NumChannelAccounts < MinNumberOfChannelAccounts || so.NumChannelAccounts > MaxNumberOfChannelAccounts { + return fmt.Errorf("num channel accounts must stay in the range from %d to %d", MinNumberOfChannelAccounts, MaxNumberOfChannelAccounts) + } + + if so.QueuePollingInterval < 6 { + return fmt.Errorf("queue polling interval must be greater than 6 seconds") + } + + if so.MaxBaseFee < txnbuild.MinBaseFee { + return fmt.Errorf("max base fee must be greater than or equal to %d", txnbuild.MinBaseFee) + } + + return nil +} + +type Manager struct { + // Data model: + dbConnectionPool db.DBConnectionPool + txModel *store.TransactionModel + chAccModel *store.ChannelAccountModel + chTxBundleModel *store.ChannelTransactionBundleModel + // job-related: + queueService defaultQueueService + txProcessingLimiter *engine.TransactionProcessingLimiter + // transaction submission: + engine *engine.SubmitterEngine + sigService engine.SignatureService + maxBaseFee int + // crash & metrics monitoring: + monitorService monitor.MonitorServiceInterface + crashTrackerClient crashtracker.CrashTrackerClient +} + +func NewManager(ctx context.Context, opts SubmitterOptions) (m *Manager, err error) { + // initialize crash tracker client + crashTrackerClient := opts.CrashTrackerClient + if opts.CrashTrackerClient == nil { + log.Ctx(ctx).Warn("crash tracker client not set, using DRY_RUN client") + crashTrackerClient, err = crashtracker.NewDryRunClient() + if err != nil { + return nil, fmt.Errorf("unable to initialize DRY_RUN crash tracker client: %w", err) + } + } + defer crashTrackerClient.FlushEvents(2 * time.Second) + defer crashTrackerClient.Recover() + + // validate options + err = opts.validate() + if err != nil { + return nil, fmt.Errorf("validating options: %w", err) + } + + // initialize database connection pool and the data models + dbConnectionPool, err := db.OpenDBConnectionPool(opts.DatabaseDSN) + if err != nil { + return nil, fmt.Errorf("opening db connection pool: %w", err) + } + defer func() { + // We only close the connection pool if the constructor finishes with an error. + // If we close the connection pool on successful cases, the manager will not be able to use it. + if err != nil { + dbConnectionPool.Close() + } + }() + txModel := store.NewTransactionModel(dbConnectionPool) + chAccModel := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + chTxBundleModel, err := store.NewChannelTransactionBundleModel(dbConnectionPool) + if err != nil { + return nil, fmt.Errorf("initializing channel transaction bundle model: %w", err) + } + + // initialize horizon client + horizonClient := &horizonclient.Client{ + HorizonURL: opts.HorizonURL, + HTTP: httpclient.DefaultClient(), + } + + // initialize default signature service + sigService, err := engine.NewDefaultSignatureService(opts.NetworkPassphrase, dbConnectionPool, opts.DistributionSeed, chAccModel, opts.PrivateKeyEncrypter, opts.DistributionSeed) + if err != nil { + return nil, fmt.Errorf("initializing default signature service: %w", err) + } + + // initialize SubmitterEngine + submitterEngine, err := engine.NewSubmitterEngine(horizonClient) + if err != nil { + return nil, fmt.Errorf("initializing submitter engine: %w", err) + } + + // validate if we have any channel accounts in the DB. + chAccCount, err := chAccModel.Count(ctx) + if err != nil { + return nil, fmt.Errorf("counting channel accounts: %w", err) + } + if chAccCount == 0 { + return nil, fmt.Errorf("no channel accounts found in the database, use the 'channel-accounts ensure' command to configure the number of accounts you want to use") + } + log.Ctx(ctx).Infof("Found '%d' channel accounts in the database...", chAccCount) + + if opts.NumChannelAccounts > chAccCount { + log.Ctx(ctx).Warnf("The number of channel accounts in the database is smaller than expected, (%d < %d)", chAccCount, opts.NumChannelAccounts) + } + + queueService := defaultQueueService{ + pollingInterval: time.Second * time.Duration(opts.QueuePollingInterval), + numChannelAccounts: opts.NumChannelAccounts, + } + + txProcessingLimiter := engine.NewTransactionProcessingLimiter(opts.NumChannelAccounts) + + return &Manager{ + dbConnectionPool: dbConnectionPool, + chAccModel: chAccModel, + txModel: txModel, + chTxBundleModel: chTxBundleModel, + + queueService: queueService, + txProcessingLimiter: txProcessingLimiter, + + engine: submitterEngine, + sigService: sigService, + maxBaseFee: opts.MaxBaseFee, + + crashTrackerClient: crashTrackerClient, + monitorService: opts.MonitorService, + }, nil +} + +// TODO: generalize the queue service in [SDP-748] to make it agnostic to databases. +type defaultQueueService struct { + pollingInterval time.Duration + numChannelAccounts int +} + +func (m *Manager) ProcessTransactions(ctx context.Context) { + defer m.crashTrackerClient.FlushEvents(2 * time.Second) + defer m.crashTrackerClient.Recover() + log.Ctx(ctx).Infof("Starting %s...", serviceName) + + // initialize signal channel, to react to OS signals + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT) + + ticker := time.NewTicker(m.queueService.pollingInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Ctx(ctx).Infof("Stopping %s due to context cancellation...", serviceName) + return + + case sig := <-signalChan: + log.Ctx(ctx).Infof("Stopping %s due to OS signal '%+v'", serviceName, sig) + return + + case <-ticker.C: + log.Ctx(ctx).Debug("Loading transactions from database...") + jobs, err := m.loadReadyForProcessingBundles(ctx) + if err != nil { + err = fmt.Errorf("attempting to load transactions from database: %w", err) + if errors.Is(err, store.ErrInsuficientChannelAccounts) { + // TODO: should we handle 'errors.Is(err, ErrInsuficientChannelAccounts)' differently? + log.Ctx(ctx).Warn(err) + } else { + m.crashTrackerClient.LogAndReportErrors(ctx, err, "") + } + continue + } + + log.Ctx(ctx).Debugf("Loaded '%d' transactions from database", len(jobs)) + + for _, job := range jobs { + worker, err := NewTransactionWorker( + m.dbConnectionPool, + m.txModel, + m.chAccModel, + m.engine, + m.sigService, + m.maxBaseFee, + m.crashTrackerClient, + m.txProcessingLimiter, + ) + if err != nil { + m.crashTrackerClient.LogAndReportErrors(ctx, err, "") + continue + } + + txJob := TxJob(*job) + go worker.Run(ctx, &txJob) + } + } + } +} + +// loadReadyForProcessingBundles loads a list of {channelAccount, Transaction, LedgerBoundsMax} bundles from the +// database which are ready to be processed. The bundles are locked for processing ar rge database, so that other +// instances of the process don't pick them up. +func (m *Manager) loadReadyForProcessingBundles(ctx context.Context) ([]*store.ChannelTransactionBundle, error) { + currentLedgerNumber, err := m.engine.LedgerNumberTracker.GetLedgerNumber() + if err != nil { + return nil, fmt.Errorf("getting current ledger number: %w", err) + } + lockToLedgerNumber := currentLedgerNumber + engine.IncrementForMaxLedgerBounds + + chTxBundles, err := m.chTxBundleModel.LoadAndLockTuples(ctx, currentLedgerNumber, lockToLedgerNumber, m.txProcessingLimiter.LimitValue()) + if err != nil { + return nil, fmt.Errorf("loading channel transaction bundles: %w", err) + } + + return chTxBundles, nil +} diff --git a/internal/transactionsubmission/manager_test.go b/internal/transactionsubmission/manager_test.go new file mode 100644 index 000000000..b8786f24c --- /dev/null +++ b/internal/transactionsubmission/manager_test.go @@ -0,0 +1,485 @@ +package transactionsubmission + +import ( + "context" + "strings" + "syscall" + "testing" + "time" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/support/log" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/mocks" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + storeMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store/mocks" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_SubmitterOptions_validate(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + testCases := []struct { + name string + wantErrContains string + submitterOptions SubmitterOptions + }{ + { + name: "validate DatabaseDSN", + submitterOptions: SubmitterOptions{}, + wantErrContains: "database DSN cannot be empty", + }, + { + name: "validate monitorService", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + }, + wantErrContains: "monitor service cannot be nil", + }, + { + name: "validate horizonURL", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + }, + wantErrContains: "horizon url cannot be empty", + }, + { + name: "validate networkPassphrase", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + }, + wantErrContains: "network passphrase \"\" is invalid", + }, + { + name: "validate PrivateKeyEncrypter", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + }, + wantErrContains: "private key encrypter cannot be nil", + }, + { + name: "validate DistributionSeed", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + }, + wantErrContains: "distribution seed is invalid", + }, + { + name: "validate NumChannelAccounts (min)", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + DistributionSeed: "SBDBQFZIIZ53A7JC2X23LSQLI5RTKV5YWDRT33YXW5LRMPKRSJYXS2EW", + NumChannelAccounts: 0, + }, + wantErrContains: "num channel accounts must stay in the range from 1 to 1000", + }, + { + name: "validate NumChannelAccounts (min)", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + DistributionSeed: "SBDBQFZIIZ53A7JC2X23LSQLI5RTKV5YWDRT33YXW5LRMPKRSJYXS2EW", + NumChannelAccounts: 1001, + }, + wantErrContains: "num channel accounts must stay in the range from 1 to 1000", + }, + { + name: "validate QueuePollingInterval", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + DistributionSeed: "SBDBQFZIIZ53A7JC2X23LSQLI5RTKV5YWDRT33YXW5LRMPKRSJYXS2EW", + NumChannelAccounts: 1, + }, + wantErrContains: "queue polling interval must be greater than 6 seconds", + }, + { + name: "validate MaxBaseFee", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + DistributionSeed: "SBDBQFZIIZ53A7JC2X23LSQLI5RTKV5YWDRT33YXW5LRMPKRSJYXS2EW", + NumChannelAccounts: 1, + QueuePollingInterval: 10, + }, + wantErrContains: "max base fee must be greater than or equal to 100", + }, + { + name: "πŸŽ‰ successfully finishes validation with nil crash tracker client", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + DistributionSeed: "SBDBQFZIIZ53A7JC2X23LSQLI5RTKV5YWDRT33YXW5LRMPKRSJYXS2EW", + NumChannelAccounts: 1, + QueuePollingInterval: 10, + MaxBaseFee: txnbuild.MinBaseFee, + }, + }, + { + name: "πŸŽ‰ successfully finishes validation with existing crash tracker client", + submitterOptions: SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + DistributionSeed: "SBDBQFZIIZ53A7JC2X23LSQLI5RTKV5YWDRT33YXW5LRMPKRSJYXS2EW", + NumChannelAccounts: 1, + QueuePollingInterval: 10, + MaxBaseFee: txnbuild.MinBaseFee, + CrashTrackerClient: &crashtracker.MockCrashTrackerClient{}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.submitterOptions.validate() + if tc.wantErrContains == "" { + assert.NoError(t, err) + } else { + assert.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } + }) + } +} + +func Test_NewManager(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + validSubmitterOptions := SubmitterOptions{ + DatabaseDSN: dbt.DSN, + MonitorService: &monitor.MonitorService{}, + HorizonURL: "https://horizon-testnet.stellar.org", + NetworkPassphrase: network.TestNetworkPassphrase, + PrivateKeyEncrypter: &utils.PrivateKeyEncrypterMock{}, + DistributionSeed: "SBDBQFZIIZ53A7JC2X23LSQLI5RTKV5YWDRT33YXW5LRMPKRSJYXS2EW", + NumChannelAccounts: 5, + QueuePollingInterval: 10, + MaxBaseFee: txnbuild.MinBaseFee, + } + + testCases := []struct { + name string + getSubmitterOptionsFn func() SubmitterOptions + numOfChannelAccountsToCreate int + wantCrashTrackerClientFn func() crashtracker.CrashTrackerClient + wantErrContains string + }{ + { + name: "returns an error if the SubmitterOptions validation fails", + wantErrContains: "validating options: ", + }, + { + name: "returns an error if the database connection cannot be opened", + getSubmitterOptionsFn: func() SubmitterOptions { + opts := validSubmitterOptions + opts.DatabaseDSN = "invalid-dsn" + return opts + }, + wantErrContains: "opening db connection pool: error pinging app DB connection pool: ", + }, + { + name: "returns an error if there are zero channel accounts in the database", + getSubmitterOptionsFn: func() SubmitterOptions { return validSubmitterOptions }, + wantErrContains: "no channel accounts found in the database, use the 'channel-accounts ensure' command to configure the number of accounts you want to use", + }, + { + name: "πŸŽ‰ Successfully creates a submitter manager. Num of channel accounts intended is EXACT MATCH (Crash Tracker initially nil)", + getSubmitterOptionsFn: func() SubmitterOptions { return validSubmitterOptions }, + numOfChannelAccountsToCreate: 5, + wantCrashTrackerClientFn: func() crashtracker.CrashTrackerClient { + crashTrackerClient, innerErr := crashtracker.NewDryRunClient() + require.NoError(t, innerErr) + return crashTrackerClient + }, + }, + { + name: "πŸŽ‰ Successfully creates a submitter manager. Num of channel accounts intended is EXACT MATCH (Crash Tracker initially not nil)", + getSubmitterOptionsFn: func() SubmitterOptions { + opts := validSubmitterOptions + opts.CrashTrackerClient, err = crashtracker.NewDryRunClient() + require.NoError(t, err) + return opts + }, + numOfChannelAccountsToCreate: 5, + }, + { + name: "πŸŽ‰ Successfully creates a submitter manager. Num of channel accounts intended is SMALLER than intended (Crash Tracker initially nil)", + numOfChannelAccountsToCreate: 1, + getSubmitterOptionsFn: func() SubmitterOptions { + opts := validSubmitterOptions + opts.CrashTrackerClient, err = crashtracker.NewDryRunClient() + require.NoError(t, err) + return opts + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + + // override empty options with the one from getSubmitterOptionsFn() + submitterOptions := SubmitterOptions{} + if tc.getSubmitterOptionsFn != nil { + submitterOptions = tc.getSubmitterOptionsFn() + } + + // create the channel accounts in the DB, if `tc.numOfChannelAccountsToCreate > 0` + if tc.numOfChannelAccountsToCreate > 0 { + _ = store.CreateChannelAccountFixtures(t, ctx, dbConnectionPool, tc.numOfChannelAccountsToCreate) + } + + getLogEntries := log.DefaultLogger.StartTest(log.WarnLevel) + gotManager, err := NewManager(ctx, submitterOptions) + logEntries := getLogEntries() + + if tc.wantErrContains != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErrContains) + require.Nil(t, gotManager) + + } else { + require.NoError(t, err) + require.NotNil(t, gotManager) + assert.NotEmpty(t, gotManager.dbConnectionPool) + defer gotManager.dbConnectionPool.Close() + + // Assert the resulting manager state: + wantConnectionPool := gotManager.dbConnectionPool + wantTxModel := &store.TransactionModel{DBConnectionPool: wantConnectionPool} + wantChAccModel := &store.ChannelAccountModel{DBConnectionPool: wantConnectionPool} + wantChTxBundleModel, err := store.NewChannelTransactionBundleModel(wantConnectionPool) + require.NoError(t, err) + + wantSubmitterEngine, err := engine.NewSubmitterEngine(&horizonclient.Client{ + HorizonURL: submitterOptions.HorizonURL, + HTTP: httpclient.DefaultClient(), + }) + require.NoError(t, err) + + wantSigService, err := engine.NewDefaultSignatureService( + submitterOptions.NetworkPassphrase, + wantConnectionPool, + submitterOptions.DistributionSeed, wantChAccModel, + submitterOptions.PrivateKeyEncrypter, + submitterOptions.DistributionSeed, + ) + require.NoError(t, err) + + wantCrashTrackerClient := submitterOptions.CrashTrackerClient + if tc.wantCrashTrackerClientFn != nil { + wantCrashTrackerClient = tc.wantCrashTrackerClientFn() + } + + txProcessingLimiter := engine.NewTransactionProcessingLimiter(submitterOptions.NumChannelAccounts) + txProcessingLimiter.CounterLastUpdated = gotManager.txProcessingLimiter.CounterLastUpdated + wantManager := &Manager{ + dbConnectionPool: wantConnectionPool, + chAccModel: wantChAccModel, + txModel: wantTxModel, + chTxBundleModel: wantChTxBundleModel, + + queueService: defaultQueueService{ + pollingInterval: time.Duration(submitterOptions.QueuePollingInterval) * time.Second, + numChannelAccounts: submitterOptions.NumChannelAccounts, + }, + + engine: wantSubmitterEngine, + sigService: wantSigService, + maxBaseFee: submitterOptions.MaxBaseFee, + + crashTrackerClient: wantCrashTrackerClient, + monitorService: submitterOptions.MonitorService, + + txProcessingLimiter: txProcessingLimiter, + } + assert.Equal(t, wantManager, gotManager) + + if tc.numOfChannelAccountsToCreate < submitterOptions.NumChannelAccounts { + didFindExpectedLogEntry := false + for _, logEntry := range logEntries { + if strings.Contains(logEntry.Message, "The number of channel accounts in the database is smaller than expected") { + didFindExpectedLogEntry = true + } + } + assert.True(t, didFindExpectedLogEntry) + } + + } + }) + } +} + +func Test_Manager_ProcessTransactions(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + type signalType string + const ( + signalTypeCancel signalType = "CANCEL" + signalTypeOSSigterm signalType = "SIGTERM" + signalTypeOSSigint signalType = "SIGINT" + signalTypeOSSigquit signalType = "SIGQUIT" + ) + + testCases := []struct { + signalType signalType + }{ + {signalTypeCancel}, + {signalTypeOSSigterm}, + {signalTypeOSSigint}, + {signalTypeOSSigquit}, + } + + for _, tc := range testCases { + t.Run(string(tc.signalType), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + defer store.DeleteAllFromChannelAccounts(t, context.Background(), dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, context.Background(), dbConnectionPool) + + // Create channel accounts to be used by the tx submitter + channelAccounts := store.CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 2) + assert.Len(t, channelAccounts, 2) + channelAccountsMap := map[string]*store.ChannelAccount{} + for _, ca := range channelAccounts { + channelAccountsMap[ca.PublicKey] = ca + } + + // Create transactions to be used by the tx submitter + transactions := store.CreateTransactionFixtures(t, ctx, dbConnectionPool, 10, "USDC", "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", keypair.MustRandom().Address(), store.TransactionStatusPending, 1) + assert.Len(t, transactions, 10) + + // Signature service + distributionKP := keypair.MustRandom() + sigService, err := engine.NewDefaultSignatureService( + network.TestNetworkPassphrase, + dbConnectionPool, + distributionKP.Seed(), + store.NewChannelAccountModel(dbConnectionPool), + &utils.PrivateKeyEncrypterMock{}, + distributionKP.Seed(), + ) + require.NoError(t, err) + + // mock ledger number tracker + const currentLedgerNumber = 123 + mockLedgerNumberTracker := &mocks.MockLedgerNumberTracker{} + mockLedgerNumberTracker.On("GetLedgerNumber").Return(currentLedgerNumber, nil) + defer mockLedgerNumberTracker.AssertExpectations(t) + + // mock horizon client + const sequenceNumber = 456 + mockHorizonClient := &horizonclient.MockClient{} + mockHorizonClient.On("AccountDetail", mock.AnythingOfType("horizonclient.AccountRequest")).Return(horizon.Account{Sequence: sequenceNumber}, nil) + mockChannelAccountStore := &storeMocks.MockChannelAccountStore{} + for pubKey, ca := range channelAccountsMap { + mockChannelAccountStore.On("Get", ctx, mock.Anything, pubKey, 0).Return(ca, nil) + } + const resultXDR = "AAAAAAAAAGQAAAAAAAAAAQAAAAAAAAAOAAAAAAAAAABw2JZZYIt4n/WXKcnDow3mbTBMPrOnldetgvGUlpTSEQAAAAA=" + mockHorizonClient. + On("SubmitFeeBumpTransactionWithOptions", mock.AnythingOfType("*txnbuild.FeeBumpTransaction"), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(horizon.Transaction{Successful: true, ResultXdr: resultXDR}, nil).Twice() + defer mockHorizonClient.AssertExpectations(t) + + submitterEngine := &engine.SubmitterEngine{ + LedgerNumberTracker: mockLedgerNumberTracker, + HorizonClient: mockHorizonClient, + } + + dryRunCrashTracker, err := crashtracker.NewDryRunClient() + require.NoError(t, err) + + queueService := defaultQueueService{ + pollingInterval: 500 * time.Millisecond, + numChannelAccounts: 2, + } + + chTxBundleModel, err := store.NewChannelTransactionBundleModel(dbConnectionPool) + require.NoError(t, err) + + manager := &Manager{ + dbConnectionPool: dbConnectionPool, + chTxBundleModel: chTxBundleModel, + chAccModel: store.NewChannelAccountModel(dbConnectionPool), + txModel: store.NewTransactionModel(dbConnectionPool), + engine: submitterEngine, + crashTrackerClient: dryRunCrashTracker, + queueService: queueService, + sigService: sigService, + maxBaseFee: txnbuild.MinBaseFee, + txProcessingLimiter: engine.NewTransactionProcessingLimiter(queueService.numChannelAccounts), + } + + go manager.ProcessTransactions(ctx) + time.Sleep(750 * time.Millisecond) // <--- this time.Sleep is used wait for the manager (QueuePollingInterval) to start and load the transactions. + + // cancel() + switch tc.signalType { + case signalTypeOSSigterm: + err = syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + require.NoError(t, err) + + case signalTypeOSSigint: + err = syscall.Kill(syscall.Getpid(), syscall.SIGINT) + require.NoError(t, err) + + case signalTypeOSSigquit: + err = syscall.Kill(syscall.Getpid(), syscall.SIGQUIT) + require.NoError(t, err) + } + + cancel() + }) + } +} diff --git a/internal/transactionsubmission/scripts/README.md b/internal/transactionsubmission/scripts/README.md new file mode 100644 index 000000000..eeae1fca8 --- /dev/null +++ b/internal/transactionsubmission/scripts/README.md @@ -0,0 +1,45 @@ +# Transaction Submission Payments Load Testing Script + +Load Test Flow: +1) Create N number of Transactions in the database for TSS to process +2) Poll database until all Transactions have completed +3) Query Horizon for each transaction and use those details to calculate and print metrics for the load test + +### CLI Flags: +```sh + --databaseUrl Postgres DB URL + --horizonUrl Horizon URL (default "https://horizon-testnet.stellar.org") + --assetCode Asset code (default "USDC") + --assetIssuer Asset issuer (default "GDQOE23CFSUMSVQK4Y5JHPPYK73VYCNHZHA7ENKCV37P6SUEO6XQBKPP") + --paymentDestination Destination address to send the payments to + --paymentCount Number of payment Transactions to create +``` + +### CLI Usage Example: +```sh +% go run internal/transactionsubmission/scripts/tss_payments_loadtest.go \ +--databaseUrl "postgres://postgres:password@localhost:5432/tss-testing?sslmode=disable" \ +--horizonUrl "https://horizon.stellar.org" \ +--assetCode "USDS" -assetIssuer "GCDUFCM7HA2AXFPWCXI55MXMCPORHOE42YIIBKN72SAMZ6WBO3G2E5TF" \ +--paymentDestination "GAR5YLLLSTPOJGK2T5P5WMSVGEFWQLDMPMZXICURGBUYJOVXARI2ZTXI" \ +--paymentCount 3 + + +All 3 transactions have completed! +Test size: 3 payment(s) +========================================================== +TSS first created payment time: 2023-06-21 11:06:55 +Stellar first observed payment time: 2023-06-21 11:07:02 +TSS last created payment time: 2023-06-21 11:06:55 +Stellar final payment observed time: 2023-06-21 11:07:26 +========================================================= +Total test latency (first created, last observed): 30.29 +========================================================== +min e2e payment latency: 6.29 +average e2e payment latency: 16.28 +max e2e payment latency: 30.28 +========================================================== +calculated average TPS: 0.10 +unique ledgers: 3 +========================================================== +``` diff --git a/internal/transactionsubmission/scripts/tss_payments_loadtest.go b/internal/transactionsubmission/scripts/tss_payments_loadtest.go new file mode 100644 index 000000000..2c60b4369 --- /dev/null +++ b/internal/transactionsubmission/scripts/tss_payments_loadtest.go @@ -0,0 +1,209 @@ +package main + +import ( + "context" + "flag" + "fmt" + "math" + "sort" + "time" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" +) + +// calculateAndPrintMetrics gets the transaction details from Horizon and calculates and prints metrics. +func calculateAndPrintMetrics(ctx context.Context, horizonClient *horizonclient.Client, txModel *store.TransactionModel, transactionIDs []string) { + transactionsTSS := make(map[string]*store.Transaction) + transactionsStellar := make(map[string]*horizon.Transaction) + transactionLatencies := make([]time.Duration, 0, len(transactionIDs)) + uniqueLedgers := make(map[int32]bool) + for _, transactionID := range transactionIDs { + tx, _ := txModel.Get(ctx, transactionID) + transactionsTSS[transactionID] = tx + } + + for txnId, txn := range transactionsTSS { + stellarTxn, err := horizonClient.TransactionDetail(txn.StellarTransactionHash.String) + // time.Sleep(100000) might need to sleep if getting rate limited + if err != nil { + fmt.Printf("failed to retrieve stellar transaction %s from horizon", txn.StellarTransactionHash.String) + } + transactionsStellar[txnId] = &stellarTxn + } + + minCreatedPaymentTime := time.Now() + for _, tx := range transactionsTSS { + if tx.CreatedAt.Before(minCreatedPaymentTime) { + minCreatedPaymentTime = *tx.CreatedAt + } + } + + maxCreatedPaymentTime := minCreatedPaymentTime + for _, tx := range transactionsTSS { + if tx.CreatedAt.After(maxCreatedPaymentTime) { + maxCreatedPaymentTime = *tx.CreatedAt + } + } + + minStellarTxnCreatedTime := time.Now() + maxStellarTxnCreatedTime := time.Time{} + for _, tx := range transactionsStellar { + if tx.LedgerCloseTime.Before(minStellarTxnCreatedTime) { + minStellarTxnCreatedTime = tx.LedgerCloseTime + } + if tx.LedgerCloseTime.After(maxStellarTxnCreatedTime) { + maxStellarTxnCreatedTime = tx.LedgerCloseTime + } + uniqueLedgers[tx.Ledger] = true + } + + minTxnLatency := time.Duration(math.MaxInt64) + maxTxnLatency := time.Duration(math.MinInt64) + for _, txId := range transactionIDs { + start := transactionsTSS[txId].CreatedAt + finish := transactionsStellar[txId].LedgerCloseTime + duration := finish.Sub(*start) + transactionLatencies = append(transactionLatencies, duration) + if duration < minTxnLatency { + minTxnLatency = duration + } + if duration > maxTxnLatency { + maxTxnLatency = duration + } + } + + sumLatency := time.Duration(0) + for _, duration := range transactionLatencies { + sumLatency += duration + } + avgLatency := sumLatency / time.Duration(len(transactionLatencies)) + + sort.Slice(transactionLatencies, func(i, j int) bool { + return transactionLatencies[i] < transactionLatencies[j] + }) + mid := len(transactionLatencies) / 2 + medianLatency := transactionLatencies[mid] + + fmt.Printf("Test size: %d payment(s)\n", len(transactionIDs)) + fmt.Printf("==========================================================\n") + fmt.Printf("TSS first created payment time: %s\n", minCreatedPaymentTime.In(time.Local).Format(time.Stamp)) + fmt.Printf("Stellar first observed payment time: %s\n", minStellarTxnCreatedTime.In(time.Local).Format(time.Stamp)) + fmt.Printf("TSS last created payment time: %s\n", maxCreatedPaymentTime.In(time.Local).Format(time.Stamp)) + fmt.Printf("Stellar final payment observed time: %s\n", maxStellarTxnCreatedTime.In(time.Local).Format(time.Stamp)) + fmt.Printf("=========================================================\n") + fmt.Printf("Total test latency (first created, last observed): %.2fs\n", maxStellarTxnCreatedTime.Sub(minCreatedPaymentTime).Seconds()) + fmt.Printf("==========================================================\n") + fmt.Printf("min e2e payment latency: %.2fs\n", minTxnLatency.Seconds()) + fmt.Printf("average e2e payment latency: %.2fs\n", avgLatency.Seconds()) + fmt.Printf("max e2e payment latency: %.2fs\n", maxTxnLatency.Seconds()) + fmt.Printf("==========================================================\n") + fmt.Printf("calculated average TPS: %.2f\n", float64(len(transactionIDs))/(maxStellarTxnCreatedTime.Sub(minCreatedPaymentTime).Seconds())) + fmt.Printf("unique ledgers: %d\n", len(uniqueLedgers)) + fmt.Printf("==========================================================\n\n") + + fmt.Printf("%.2f, %d, %d, %d, %d, %s, %s, %.2f, %.2f, %.2f, %.2f, %.2f\n\n", + float64(len(transactionIDs))/(maxStellarTxnCreatedTime.Sub(minCreatedPaymentTime).Seconds()), + 0, + 0, + 0, + len(uniqueLedgers), + minCreatedPaymentTime.In(time.Local).Format("2006-01-02 15:04:05"), + minStellarTxnCreatedTime.In(time.Local).Format("2006-01-02 15:04:05"), + maxStellarTxnCreatedTime.Sub(minCreatedPaymentTime).Seconds(), + minTxnLatency.Seconds(), + medianLatency.Seconds(), + avgLatency.Seconds(), + maxTxnLatency.Seconds(), + ) +} + +// createPaymentTransactions creates bulk transactions in the submitter_transactions table for TSS to process. +func createPaymentTransactions(ctx context.Context, txModel *store.TransactionModel, paymentCount int, assetCode, assetIssuer, destination string) []store.Transaction { + transactions := make([]store.Transaction, 0, paymentCount) + for i := 0; i < paymentCount; i++ { + externalID := fmt.Sprintf("external-id-%d", i) + transactions = append(transactions, store.Transaction{ + ExternalID: externalID, + AssetCode: assetCode, + AssetIssuer: assetIssuer, + Amount: 0.1, + Destination: destination, + }) + } + insertedTransactions, err := txModel.BulkInsert(ctx, txModel.DBConnectionPool, transactions) + if err != nil { + log.Ctx(ctx).Errorf("Error inserting transactions: %v", err.Error()) + } + return insertedTransactions +} + +// waitForTransactionsToComplete queries the database for each transaction that was created as waits for all of them to +// be in either SUCCESS or ERROR state. +func waitForTransactionsToComplete(ctx context.Context, txModel *store.TransactionModel, transactionIDs []string) { + ticker := time.NewTicker(10 * time.Second) + tickerChan := ticker.C + for range tickerChan { + completedTransactions := 0 + for _, transactionID := range transactionIDs { + // TODO - optimize this into a single query to get all statuses + tx, err := txModel.Get(ctx, transactionID) + if err != nil { + log.Ctx(ctx).Errorf("Error getting transaction %s: %v", transactionID, err) + } else if tx.Status == store.TransactionStatusError || tx.Status == store.TransactionStatusSuccess { + completedTransactions += 1 + } + } + + if completedTransactions == len(transactionIDs) { + // All transactions are complete, exit the loop + ticker.Stop() + fmt.Printf("All %d transactions have completed!\n", len(transactionIDs)) + return + } + fmt.Printf("%d/%d transactions have completed...\n", completedTransactions, len(transactionIDs)) + } +} + +// This script is just meant for creating a large number of payments for TESTING. +// There is minimal error handling and minimal checking for valid input parameters. +func main() { + paymentCount := flag.Int("paymentCount", 0, "how many payments to create") + databaseUrl := flag.String("databaseUrl", "", "database to create the transactions in") + horizonUrl := flag.String("horizonUrl", "https://horizon-testnet.stellar.org", "horizon url") + assetCode := flag.String("assetCode", "USDC", "asset code") + assetIssuer := flag.String("assetIssuer", "GDQOE23CFSUMSVQK4Y5JHPPYK73VYCNHZHA7ENKCV37P6SUEO6XQBKPP", "asset issuer") + paymentDestination := flag.String("paymentDestination", "", "destination address of the payment") + flag.Parse() + + ctx := context.Background() + dbConnectionPool, err := db.OpenDBConnectionPool(*databaseUrl) + if err != nil { + fmt.Printf("Error opening db connection pool in init: %s ", err.Error()) + } + + txModel := &store.TransactionModel{DBConnectionPool: dbConnectionPool} + + // create horizon client + horizonClient := &horizonclient.Client{ + HorizonURL: *horizonUrl, + HTTP: httpclient.DefaultClient(), + } + + // 1) create the payment transactions + transactionIDs := createPaymentTransactions(ctx, txModel, *paymentCount, *assetCode, *assetIssuer, *paymentDestination) + txIDs := make([]string, 0, len(transactionIDs)) + for _, tx := range transactionIDs { + txIDs = append(txIDs, tx.ID) + } + + // 2) wait for all Transactions to be marked as either Success/Error + waitForTransactionsToComplete(ctx, txModel, txIDs) + + // 3) calculate and print metrics + calculateAndPrintMetrics(ctx, horizonClient, txModel, txIDs) +} diff --git a/internal/transactionsubmission/services/channel_account_service.go b/internal/transactionsubmission/services/channel_account_service.go new file mode 100644 index 000000000..584b821ae --- /dev/null +++ b/internal/transactionsubmission/services/channel_account_service.go @@ -0,0 +1,353 @@ +package services + +import ( + "context" + "fmt" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/log" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + txSub "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" +) + +type ChannelAccountsService struct { + dbConnectionPool db.DBConnectionPool + caStore store.ChannelAccountStore + horizonClient horizonclient.ClientInterface + ledgerNumberTracker engine.LedgerNumberTracker +} + +type ChannelAccountsServiceInterface interface { + CreateChannelAccountsOnChain(context.Context, ChannelAccountServiceOptions) error + VerifyChannelAccounts(context.Context, ChannelAccountServiceOptions) error + DeleteChannelAccount(context.Context, ChannelAccountServiceOptions) error + EnsureChannelAccountsCount(context.Context, ChannelAccountServiceOptions) error + ViewChannelAccounts(context.Context) error +} + +// make sure *ChannelAccountsService implements ChannelAccountsServiceInterface: +var _ ChannelAccountsServiceInterface = (*ChannelAccountsService)(nil) + +type ChannelAccountServiceOptions struct { + ChannelAccountID string + DatabaseDSN string + DeleteAllAccounts bool + DeleteInvalidAcccounts bool + EncryptKey bool + HorizonUrl string + MaxBaseFee int + NetworkPassphrase string + NumChannelAccounts int + RootSeed string +} + +func NewChannelAccountService(opts ChannelAccountServiceOptions) (*ChannelAccountsService, error) { + dbConnectionPool, err := db.OpenDBConnectionPool(opts.DatabaseDSN) + if err != nil { + return nil, fmt.Errorf("opening db connection pool: %w", err) + } + + caModel := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + horizonClient := &horizonclient.Client{ + HorizonURL: opts.HorizonUrl, + HTTP: httpclient.DefaultClient(), + } + + ledgerNumberTracker, err := engine.NewLedgerNumberTracker(horizonClient) + if err != nil { + return nil, fmt.Errorf("cannot create new ledger number tracker") + } + + return &ChannelAccountsService{ + dbConnectionPool: dbConnectionPool, + caStore: caModel, + horizonClient: horizonClient, + ledgerNumberTracker: ledgerNumberTracker, + }, nil +} + +// CreateChannelAccountsOnChain creates a specified count of sponsored channel accounts onchain and internally in the database. +func (s *ChannelAccountsService) CreateChannelAccountsOnChain(ctx context.Context, opts ChannelAccountServiceOptions) error { + log.Ctx(ctx).Infof("NumChannelAccounts: %d, Horizon: %s, Passphrase: %s, EncryptKey?: %t", opts.NumChannelAccounts, opts.HorizonUrl, opts.NetworkPassphrase, opts.EncryptKey) + // createAccountsInBatch creates count number of channel accounts in batches of MaxBatchSize or less per loop + err := createAccountsInBatch(ctx, s.dbConnectionPool, opts, s.horizonClient, s.caStore, s.ledgerNumberTracker) + if err != nil { + return fmt.Errorf("creating channel accounts in batch in CreateChannelAccountsOnChain: %w", err) + } + + return nil +} + +func createAccountsInBatch( + ctx context.Context, + dbConnectionPool db.DBConnectionPool, + opts ChannelAccountServiceOptions, + horizonClient horizonclient.ClientInterface, + chAccModel store.ChannelAccountStore, + ledgerNumberTracker engine.LedgerNumberTracker, +) error { + sigService, err := engine.NewDefaultSignatureService(opts.NetworkPassphrase, dbConnectionPool, opts.RootSeed, chAccModel, &utils.DefaultPrivateKeyEncrypter{}, opts.RootSeed) + if err != nil { + return fmt.Errorf("creating signature service: %w", err) + } + + numberOfAccountsToCreate := opts.NumChannelAccounts + for numberOfAccountsToCreate > 0 { + batchSize := numberOfAccountsToCreate + if numberOfAccountsToCreate > txSub.MaximumCreateAccountOperationsPerStellarTx { + // only create a MaxBatchSize (19) of accounts per transaction, this is due to the signature limit of a transaction + batchSize = txSub.MaximumCreateAccountOperationsPerStellarTx + } + log.Ctx(ctx).Infof("batch size: %d", batchSize) + + currLedgerNumber, err := ledgerNumberTracker.GetLedgerNumber() + if err != nil { + return fmt.Errorf("cannot get current ledger number: %w", err) + } + accounts, err := txSub.CreateChannelAccountsOnChain( + ctx, + horizonClient, + batchSize, + opts.MaxBaseFee, + opts.EncryptKey, + sigService, + currLedgerNumber, + ) + if err != nil { + return err + } + + // write the channel accounts to the database + for _, account := range accounts { + _, err = chAccModel.Unlock(ctx, dbConnectionPool, account) + if err != nil { + return fmt.Errorf("cannot unlock account %s", account) + } + log.Ctx(ctx).Infof("Created channel account with public key %s", account) + } + numberOfAccountsToCreate -= len(accounts) + } + + return nil +} + +// VerifyChannelAccounts verifies the existance of all channel accounts in the data store onchain. +func (c *ChannelAccountsService) VerifyChannelAccounts(ctx context.Context, opts ChannelAccountServiceOptions) error { + log.Ctx(ctx).Infof("DeleteInvalidAccounts?: %t", opts.DeleteInvalidAcccounts) + accounts, err := c.caStore.GetAll(ctx, c.dbConnectionPool, 0, 0) + if err != nil { + return fmt.Errorf("loading channel accounts from database in VerifyChannelAccounts: %w", err) + } + + log.Ctx(ctx).Infof("Discovered %d channel accounts in database", len(accounts)) + + invalidAccountsCount := 0 + for _, account := range accounts { + _, err := c.horizonClient.AccountDetail(horizonclient.AccountRequest{AccountID: account.PublicKey}) + if err != nil { + if horizonclient.IsNotFoundError(err) { + log.Ctx(ctx).Warnf("Account %s does not exist on the network", account.PublicKey) + if opts.DeleteInvalidAcccounts { + deleteErr := c.caStore.Delete(ctx, c.dbConnectionPool, account.PublicKey) + if deleteErr != nil { + return fmt.Errorf( + "deleting %s from database in VerifyChannelAccounts: %w", + account.PublicKey, + deleteErr, + ) + } + log.Ctx(ctx).Infof("Successfully deleted channel account %q", account.PublicKey) + } + + invalidAccountsCount++ + } else { + // return any error other than 404's + return fmt.Errorf( + "retrieving account details through horizon for account %s in VerifyChannelAccounts: %w", + account.PublicKey, + horizonclient.GetError(err), + ) + } + } + } + + if invalidAccountsCount == 0 { + log.Ctx(ctx).Info("No invalid channel accounts discovered") + } + + return nil +} + +func (s *ChannelAccountsService) EnsureChannelAccountsCount( + ctx context.Context, + opts ChannelAccountServiceOptions, +) error { + log.Ctx(ctx).Infof("Desired Accounts Count: %d", opts.NumChannelAccounts) + + numAccountsToEnsure := opts.NumChannelAccounts + if numAccountsToEnsure > txSub.MaxNumberOfChannelAccounts { + return fmt.Errorf( + "count entered %d is greater than the channel accounts count limit %d in EnsureChannelAccountsCount", + numAccountsToEnsure, + txSub.MaxNumberOfChannelAccounts, + ) + } + + accountsCount, err := s.caStore.Count(ctx) + if err != nil { + return fmt.Errorf("retrieving channel accounts count in EnsureChannelAccountsCount: %w", err) + } + + if accountsCount == numAccountsToEnsure { + log.Ctx(ctx).Infof("There are exactly %d managed channel accounts currently. Exiting...", numAccountsToEnsure) + return nil + } else if accountsCount > numAccountsToEnsure { // delete some accounts + numAccountsToDelete := accountsCount - numAccountsToEnsure + log.Ctx(ctx).Infof("Deleting %d accounts...", numAccountsToDelete) + + err = s.deleteChannelAccounts(ctx, opts, numAccountsToDelete) + if err != nil { + return fmt.Errorf("deleting %d accounts in EnsureChannelAccountsCount: %w", numAccountsToDelete, err) + } + } else { // add some accounts + numAccountsToCreate := numAccountsToEnsure - accountsCount + opts.NumChannelAccounts = numAccountsToCreate + log.Ctx(ctx).Infof("Creating %d accounts...", numAccountsToCreate) + + createAccErr := createAccountsInBatch(ctx, s.dbConnectionPool, opts, s.horizonClient, s.caStore, s.ledgerNumberTracker) + if createAccErr != nil { + return fmt.Errorf("creating channel accounts in batch in EnsureChannelAccountsCount: %w", createAccErr) + } + } + + return nil +} + +// DeleteChannelAccount removes a specified channel account from the database and onchain. +func (s *ChannelAccountsService) DeleteChannelAccount( + ctx context.Context, + opts ChannelAccountServiceOptions, +) error { + if opts.ChannelAccountID != "" { // delete specified accounts + currLedgerNum, err := s.ledgerNumberTracker.GetLedgerNumber() + if err != nil { + return fmt.Errorf("retrieving current ledger number in DeleteChannelAccount: %w", err) + } + + lockedUntilLedgerNumber := currLedgerNum + engine.IncrementForMaxLedgerBounds + channelAccount, err := s.caStore.GetAndLock(ctx, opts.ChannelAccountID, currLedgerNum, lockedUntilLedgerNumber) + if err != nil { + return fmt.Errorf( + "retrieving account %s from database in DeleteChannelAccount: %w", opts.ChannelAccountID, err) + } + + err = s.deleteChannelAccount(ctx, opts, channelAccount.PublicKey, lockedUntilLedgerNumber) + if err != nil { + return fmt.Errorf("deleting account %s in DeleteChannelAccount: %w", channelAccount.PublicKey, err) + } + } else if opts.DeleteAllAccounts { // delete all managed accounts + accountsCount, err := s.caStore.Count(ctx) + log.Ctx(ctx).Infof("Found %d accounts to delete...", accountsCount) + + if err != nil { + return fmt.Errorf("cannot get count for accounts in DeleteChannelAccount: %w", err) + } + err = s.deleteChannelAccounts(ctx, opts, accountsCount) + if err != nil { + return fmt.Errorf("cannot delete all accounts in DeleteChannelAccount: %w", err) + } + } else { + log.Ctx(ctx).Warn("Specify an account to delete or enable deletion of all accounts") + } + + return nil +} + +func (s *ChannelAccountsService) ViewChannelAccounts(ctx context.Context) error { + accounts, err := s.caStore.GetAll(ctx, s.dbConnectionPool, 0, 0) + if err != nil { + return fmt.Errorf("loading channel accounts from database in ViewChannelAccounts: %w", err) + } + + log.Ctx(ctx).Infof("Discovered %d channel accounts in database...", len(accounts)) + + for _, acc := range accounts { + log.Ctx(ctx).Infof("Found account %s", acc.PublicKey) + } + + return nil +} + +func (s *ChannelAccountsService) deleteChannelAccounts(ctx context.Context, opts ChannelAccountServiceOptions, numAccountsToDelete int) error { + for i := 0; i < numAccountsToDelete; i++ { + currLedgerNum, err := s.ledgerNumberTracker.GetLedgerNumber() + if err != nil { + return fmt.Errorf("retrieving current ledger number in DeleteChannelAccount: %w", err) + } + + lockedUntilLedgerNumber := currLedgerNum + engine.IncrementForMaxLedgerBounds + accounts, err := s.caStore.GetAndLockAll(ctx, currLedgerNum, lockedUntilLedgerNumber, 1) + if err != nil { + return fmt.Errorf("cannot retrieve free channel account: %w", err) + } + + if len(accounts) == 0 { + log.Ctx(ctx).Warn("Could not find any accounts to deleting. Exiting...") + return nil + } + + accountToDelete := accounts[0] + err = s.deleteChannelAccount(ctx, opts, accountToDelete.PublicKey, lockedUntilLedgerNumber) + if err != nil { + return fmt.Errorf("cannot delete account %s: %w", accountToDelete.PublicKey, err) + } + } + + return nil +} + +func (s *ChannelAccountsService) deleteChannelAccount( + ctx context.Context, + opts ChannelAccountServiceOptions, + chAccAddress string, + lockedUntilLedger int, +) error { + sigService, err := engine.NewDefaultSignatureService(opts.NetworkPassphrase, s.dbConnectionPool, opts.RootSeed, s.caStore, &utils.DefaultPrivateKeyEncrypter{}, opts.RootSeed) + if err != nil { + return fmt.Errorf("creating signature service: %w", err) + } + + _, err = s.horizonClient.AccountDetail(horizonclient.AccountRequest{AccountID: chAccAddress}) + if err != nil { + if horizonclient.IsNotFoundError(err) { + log.Ctx(ctx).Warnf("Account %s does not exist on the network", chAccAddress) + err = sigService.Delete(ctx, chAccAddress, lockedUntilLedger) + if err != nil { + return fmt.Errorf("deleting %s from signature service: %w", chAccAddress, err) + } + } else { + return fmt.Errorf("cannot find account %s on the network: %w", chAccAddress, err) + } + } else { + err = txSub.DeleteChannelAccountOnChain( + ctx, + s.horizonClient, + chAccAddress, + int64(opts.MaxBaseFee), + sigService, + lockedUntilLedger, + ) + if err != nil { + return fmt.Errorf("deleting account %s onchain: %w", opts.ChannelAccountID, err) + } + } + + log.Ctx(ctx).Infof("Successfully deleted channel account %q", chAccAddress) + + return nil +} diff --git a/internal/transactionsubmission/services/channel_accounts_service_test.go b/internal/transactionsubmission/services/channel_accounts_service_test.go new file mode 100644 index 000000000..a3dae6e55 --- /dev/null +++ b/internal/transactionsubmission/services/channel_accounts_service_test.go @@ -0,0 +1,823 @@ +package services + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/keypair" + "github.com/stellar/go/protocols/horizon" + + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/problem" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + engineMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/mocks" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + storeMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store/mocks" +) + +func Test_ChannelAccounts_CreateAccount_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + opts := ChannelAccountServiceOptions{ + NumChannelAccounts: 2, + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + EncryptKey: true, + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currLedgerNumber := 100 + + ctx := context.Background() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{AccountID: rootAccount.Address()}, nil) + mHorizonClient.On( + "SubmitTransactionWithOptions", + mock.Anything, + horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}, + ).Return(horizon.Transaction{}, nil).Once() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNumber, nil).Once() + mChannelAccountStore.On( + "BatchInsertAndLock", + ctx, + mock.AnythingOfType("[]*store.ChannelAccount"), + currLedgerNumber, + currLedgerNumber+engine.IncrementForMaxLedgerBounds, + ).Return(nil).Once() + mChannelAccountStore.On( + "Get", ctx, dbConnectionPool, mock.AnythingOfType("string"), 0, + ).Return(&store.ChannelAccount{PrivateKey: keypair.MustRandom().Seed()}, nil).Twice() + mChannelAccountStore.On("Unlock", ctx, mock.Anything, mock.AnythingOfType("string")).Return(nil, nil).Twice() + + err = cas.CreateChannelAccountsOnChain(ctx, opts) + require.NoError(t, err) + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) + + store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) +} + +func Test_ChannelAccounts_CreateAccount_CannotFindRootAccount_Failure(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + opts := ChannelAccountServiceOptions{ + NumChannelAccounts: 2, + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SDL4E4RF6BHX77DBKE63QC4H4LQG7S7D2PB4TSF64LTHDIHP7UUJHH2V", + EncryptKey: true, + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currLedgerNumber := 100 + + ctx := context.Background() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{}, errors.New("cannot find root account")) + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNumber, nil).Once() + + err = cas.CreateChannelAccountsOnChain(ctx, opts) + require.ErrorContains( + t, + err, + "creating channel accounts in batch in CreateChannelAccountsOnChain: failed to retrieve root account: cannot find root account", + ) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_CreateAccount_Insert_Failure(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + opts := ChannelAccountServiceOptions{ + NumChannelAccounts: 2, + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + EncryptKey: true, + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currLedgerNumber := 100 + + ctx := context.Background() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNumber, nil).Once() + mChannelAccountStore.On( + "BatchInsertAndLock", + ctx, + mock.AnythingOfType("[]*store.ChannelAccount"), + currLedgerNumber, + currLedgerNumber+engine.IncrementForMaxLedgerBounds, + ).Return(errors.New("failure inserting tx in DB")) + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{AccountID: rootAccount.Address()}, nil) + + err = cas.CreateChannelAccountsOnChain(ctx, opts) + require.EqualError( + t, + err, + "creating channel accounts in batch in CreateChannelAccountsOnChain: failed to insert channel accounts into signature service: batch inserting channel accounts: failure inserting tx in DB", + ) + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_VerifyAccounts_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + } + + opts := ChannelAccountServiceOptions{ + DeleteInvalidAcccounts: false, + } + + channelAccounts := []*store.ChannelAccount{ + { + PublicKey: "GC3TKX2B6V7RSIU7UWNJ6MIA7PBTVBXGG7B43HYXRDLHB2DI6FVCYDE3", + }, + { + PublicKey: "GAV6VOD2JY6CYJ2XT7U4IH5HL5RJZXEDZFC7CQX5SR7SLLVOP3KPOFH2", + }, + } + + ctx := context.Background() + mChannelAccountStore.On("GetAll", ctx, dbConnectionPool, 0, 0).Return(channelAccounts, nil).Once() + for _, acc := range channelAccounts { + mHorizonClient.On( + "AccountDetail", + horizonclient.AccountRequest{AccountID: acc.PublicKey}, + ).Return(horizon.Account{AccountID: acc.PublicKey}, nil).Once() + } + + err = cas.VerifyChannelAccounts(ctx, opts) + require.NoError(t, err) + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) +} + +func Test_ChannelAccounts_VerifyAccounts_LoadChannelAccountsError_Failure(t *testing.T) { + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: &horizonclient.MockClient{}, + } + + opts := ChannelAccountServiceOptions{ + DeleteInvalidAcccounts: false, + } + + ctx := context.Background() + mChannelAccountStore. + On("GetAll", ctx, nil, 0, 0). + Return(nil, errors.New("cannot load channel accounts from database")). + Once() + + err := cas.VerifyChannelAccounts(ctx, opts) + require.EqualError( + t, + err, + "loading channel accounts from database in VerifyChannelAccounts: cannot load channel accounts from database", + ) + mChannelAccountStore.AssertExpectations(t) +} + +func Test_ChannelAccounts_VerifyAccounts_NotFound(t *testing.T) { + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + } + + opts := ChannelAccountServiceOptions{ + DeleteInvalidAcccounts: true, + } + + channelAccounts := []*store.ChannelAccount{ + { + PublicKey: "GC3TKX2B6V7RSIU7UWNJ6MIA7PBTVBXGG7B43HYXRDLHB2DI6FVCYDE3", + }, + { + PublicKey: "GAV6VOD2JY6CYJ2XT7U4IH5HL5RJZXEDZFC7CQX5SR7SLLVOP3KPOFH2", + }, + } + + ctx := context.Background() + mChannelAccountStore.On("GetAll", ctx, nil, 0, 0).Return(channelAccounts, nil).Once() + for _, acc := range channelAccounts { + mHorizonClient.On( + "AccountDetail", + horizonclient.AccountRequest{AccountID: acc.PublicKey}, + ).Return(horizon.Account{}, horizonclient.Error{ + Problem: problem.P{ + Type: "https://stellar.org/horizon-errors/not_found", + }, + }).Once() + mChannelAccountStore.On("Delete", ctx, nil, acc.PublicKey).Return(nil).Once() + } + + getEntries := log.DefaultLogger.StartTest(log.WarnLevel) + + err := cas.VerifyChannelAccounts(ctx, opts) + require.NoError(t, err) + + entries := getEntries() + assert.Equal(t, len(entries), 2) + for i, entry := range entries { + assert.Equal( + t, + entry.Message, + fmt.Sprintf("Account %s does not exist on the network", channelAccounts[i].PublicKey), + ) + } + + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) +} + +func Test_ChannelAccounts_DeleteAccount_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + channelAccount := &store.ChannelAccount{ + PublicKey: "GDXSRISWI6ZVFVVOUU2DNKVHUYEJQZ63A37P6C5NGKXBROW5WW5W6HW3", + PrivateKey: "YVeMG89DMl2Ku7IeGCumrvneDydfuW+2q4EKQoYhPRpKS/A1bKhNzAa7IjyLiA6UwTESsM6Hh8nactmuOfqUT38YVTx68CIgG6OuwCHPrmws57Tf", + } + + opts := ChannelAccountServiceOptions{ + ChannelAccountID: channelAccount.PublicKey, + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + DeleteAllAccounts: false, + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currLedgerNum := 100 + + ctx := context.Background() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNum, nil).Once() + mChannelAccountStore.On("GetAndLock", ctx, opts.ChannelAccountID, currLedgerNum, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(channelAccount, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: opts.ChannelAccountID}). + Return(horizon.Account{}, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{AccountID: rootAccount.Address()}, nil).Once() + mChannelAccountStore.On("Get", ctx, mock.Anything, opts.ChannelAccountID, 0). + Return(channelAccount, nil).Once() + mHorizonClient.On( + "SubmitTransactionWithOptions", + mock.Anything, + horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}, + ).Return(horizon.Transaction{}, nil).Once() + mChannelAccountStore.On("DeleteIfLockedUntil", ctx, opts.ChannelAccountID, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(nil).Once() + + err = cas.DeleteChannelAccount(ctx, opts) + require.NoError(t, err) + + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_DeleteAccount_All_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + channelAccounts := []*store.ChannelAccount{ + { + PublicKey: "GDXSRISWI6ZVFVVOUU2DNKVHUYEJQZ63A37P6C5NGKXBROW5WW5W6HW3", + PrivateKey: "YVeMG89DMl2Ku7IeGCumrvneDydfuW+2q4EKQoYhPRpKS/A1bKhNzAa7IjyLiA6UwTESsM6Hh8nactmuOfqUT38YVTx68CIgG6OuwCHPrmws57Tf", + }, + { + PublicKey: "GAORBNVUS7TZI6M47CE2XKJIYUZGWTQLPJTU3FEQCFR47H6LTLCTK25P", + PrivateKey: "I9uPlXL/KvZOOK7kVHHjdFaSeJARV/lvv0YG7P2GCYclgz1MCmthiSZv0BF5HK13PmB4qgzMG9cebxShEZ8AjXDHZA4IOrt+4stE6GF8UR8jdWkG", + }, + } + + opts := ChannelAccountServiceOptions{ + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + DeleteAllAccounts: true, + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currLedgerNum := 1000 + + ctx := context.Background() + mChannelAccountStore.On("Count", ctx).Return(len(channelAccounts), nil).Once() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNum, nil).Times(len(channelAccounts)) + for _, acc := range channelAccounts { + mChannelAccountStore. + On("GetAndLockAll", ctx, currLedgerNum, currLedgerNum+engine.IncrementForMaxLedgerBounds, 1). + Return([]*store.ChannelAccount{acc}, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: acc.PublicKey}). + Return(horizon.Account{}, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{AccountID: rootAccount.Address()}, nil).Once() + mHorizonClient.On( + "SubmitTransactionWithOptions", + mock.Anything, + horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}, + ).Return(horizon.Transaction{}, nil).Once() + mChannelAccountStore.On("Get", ctx, mock.Anything, acc.PublicKey, 0). + Return(acc, nil).Once() + mChannelAccountStore.On("DeleteIfLockedUntil", ctx, acc.PublicKey, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(nil).Once() + } + + err = cas.DeleteChannelAccount(ctx, opts) + require.NoError(t, err) + + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_DeleteAccount_FindByPublicKey_Failure(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + opts := ChannelAccountServiceOptions{ + ChannelAccountID: "GDKMLSJSPHFWB26JV7ESWLJAKJ6KDTLQWYFT2T4ZVXFFHWBINUEJKASM", + DeleteAllAccounts: false, + } + + currLedgerNum := 1000 + + ctx := context.Background() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNum, nil).Once() + mChannelAccountStore.On("GetAndLock", ctx, opts.ChannelAccountID, currLedgerNum, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(nil, errors.New("db error")).Once() + + err = cas.DeleteChannelAccount(ctx, opts) + require.ErrorContains(t, + err, + fmt.Sprintf("retrieving account %s from database in DeleteChannelAccount: db error", opts.ChannelAccountID), + ) + + mChannelAccountStore.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_DeleteAccount_DeleteFromDatabaseError(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + channelAccount := &store.ChannelAccount{ + PublicKey: "GAMWDQPPO3MXDQHZWYQLCQMKMBVDDCV7WIRKLCALWJPI7MIQHYNERTXS", + PrivateKey: "SBS2DJJSWZKKADWE4QEFN6CWXPM6KAFULKVJWO5VN7NIFDP6HFZXF6J7", + } + + opts := ChannelAccountServiceOptions{ + ChannelAccountID: channelAccount.PublicKey, + NetworkPassphrase: "Test SDF Network ; September 2015", + DeleteAllAccounts: false, + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + } + + currLedgerNum := 1000 + + ctx := context.Background() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNum, nil).Once() + mChannelAccountStore.On("GetAndLock", ctx, opts.ChannelAccountID, currLedgerNum, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(channelAccount, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: opts.ChannelAccountID}). + Return(horizon.Account{}, horizonclient.Error{ + Problem: problem.P{ + Type: "https://stellar.org/horizon-errors/not_found", + }, + }).Once() + mChannelAccountStore. + On("DeleteIfLockedUntil", ctx, opts.ChannelAccountID, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(errors.New("db error")). + Once() + + err = cas.DeleteChannelAccount(ctx, opts) + require.Error(t, err) + require.ErrorContains( + t, + err, + fmt.Sprintf( + `deleting account %[1]s in DeleteChannelAccount: deleting %[1]s from signature service: deleting channel account "%[1]s" from database: db error`, + opts.ChannelAccountID, + ), + ) + + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_DeleteAccount_SubmitTransaction_Failure(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + channelAccount := &store.ChannelAccount{ + PublicKey: "GDXSRISWI6ZVFVVOUU2DNKVHUYEJQZ63A37P6C5NGKXBROW5WW5W6HW3", + PrivateKey: "SDHGNWPVZJML64GMSQFVX7RAZBJXO3SWOMEGV77IPXUMKHHEOFD2LC75", + } + + opts := ChannelAccountServiceOptions{ + ChannelAccountID: channelAccount.PublicKey, + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currLedgerNum := 1000 + + ctx := context.Background() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNum, nil).Once() + mChannelAccountStore.On("GetAndLock", ctx, opts.ChannelAccountID, currLedgerNum, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(channelAccount, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: opts.ChannelAccountID}). + Return(horizon.Account{}, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{AccountID: rootAccount.Address()}, nil).Once() + mChannelAccountStore.On("Get", ctx, mock.Anything, opts.ChannelAccountID, 0). + Return(channelAccount, nil).Once() + mHorizonClient.On( + "SubmitTransactionWithOptions", + mock.Anything, + horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}, + ).Return(horizon.Transaction{}, errors.New("horizon client error")).Once() + + err = cas.DeleteChannelAccount(ctx, opts) + assert.ErrorContains( + t, + err, + fmt.Sprintf( + "deleting account %[1]s in DeleteChannelAccount: deleting account %[1]s onchain: submitting remove account transaction to the network for account %[1]s: horizon client error", + opts.ChannelAccountID, + ), + ) + + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_EnsureChannelAccounts_Exact_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + } + + opts := ChannelAccountServiceOptions{NumChannelAccounts: 2} + + ctx := context.Background() + mChannelAccountStore.On("Count", ctx). + Return(opts.NumChannelAccounts, nil).Once() + getEntries := log.DefaultLogger.StartTest(log.InfoLevel) + + err = cas.EnsureChannelAccountsCount(ctx, opts) + require.NoError(t, err) + + entries := getEntries() + assert.Equal(t, + entries[1].Message, + fmt.Sprintf("There are exactly %d managed channel accounts currently. Exiting...", opts.NumChannelAccounts), + ) + + mChannelAccountStore.AssertExpectations(t) +} + +func Test_ChannelAccounts_EnsureChannelAccounts_Add_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + desiredCount := 5 + opts := ChannelAccountServiceOptions{ + NumChannelAccounts: desiredCount, + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currChannelAccountsCount := 2 + currLedgerNum := 100 + + ctx := context.Background() + mChannelAccountStore.On("Count", ctx).Return(currChannelAccountsCount, nil).Once() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNum, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{AccountID: rootAccount.Address()}, nil).Once() + mHorizonClient.On( + "SubmitTransactionWithOptions", + mock.Anything, + horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}, + ).Return(horizon.Transaction{}, nil).Once() + mChannelAccountStore.On("BatchInsertAndLock", ctx, mock.AnythingOfType("[]*store.ChannelAccount"), currLedgerNum, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(nil).Once() + mChannelAccountStore. + On("Get", ctx, mock.Anything, mock.AnythingOfType("string"), 0). + Return(&store.ChannelAccount{PrivateKey: keypair.MustRandom().Seed()}, nil). + Times(desiredCount - currChannelAccountsCount) + mChannelAccountStore.On("Unlock", ctx, mock.Anything, mock.AnythingOfType("string")).Return(nil, nil). + Times(desiredCount - currChannelAccountsCount) + + err = cas.EnsureChannelAccountsCount(ctx, opts) + require.NoError(t, err) + + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_EnsureChannelAccounts_Delete_Success(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + mHorizonClient := &horizonclient.MockClient{} + mLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: mHorizonClient, + dbConnectionPool: dbConnectionPool, + ledgerNumberTracker: mLedgerNumberTracker, + } + + opts := ChannelAccountServiceOptions{ + NumChannelAccounts: 2, + MaxBaseFee: 100, + NetworkPassphrase: "Test SDF Network ; September 2015", + RootSeed: "SBMW2WDSVTGT2N2PCBF3PV7WBOIKVTGGIEBUUYMDX3CKTDD5HY3UIHV4", + } + + rootAccount := keypair.MustParseFull(opts.RootSeed) + currChannelAccountsCount := 4 + + channelAccounts := []*store.ChannelAccount{ + { + PublicKey: "GCCVRQS7R7V66QDPBZKHRVPOVPCG253BPUSYPWC4GZN54AVXIRHW4QYN", + PrivateKey: "SCDC7JG53WIFEHFI72KIS6PMMVFDNZDT32VRQY45JVE4FEYNTQYXMWWJ", + }, + { + PublicKey: "GDHVIPZMT6UWY2SNG7RBHK5P5NHXIIWMVINEARIO7QLBVNRJDYUNACDF", + PrivateKey: "SDRLEKUEM5535VWJSRPICXLVPOWPVSTVWFNQSVIJ6M3TPHXBQBGHWNJ2", + }, + } + + currLedgerNum := 1000 + + ctx := context.Background() + mChannelAccountStore.On("Count", ctx).Return(currChannelAccountsCount, nil).Once() + mLedgerNumberTracker.On("GetLedgerNumber").Return(currLedgerNum, nil).Times(currChannelAccountsCount - opts.NumChannelAccounts) + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: rootAccount.Address()}). + Return(horizon.Account{AccountID: rootAccount.Address()}, nil).Times(currChannelAccountsCount - opts.NumChannelAccounts) + + for _, acc := range channelAccounts { + mChannelAccountStore.On("GetAndLockAll", ctx, currLedgerNum, currLedgerNum+engine.IncrementForMaxLedgerBounds, 1). + Return([]*store.ChannelAccount{acc}, nil).Once() + mHorizonClient.On("AccountDetail", horizonclient.AccountRequest{AccountID: acc.PublicKey}). + Return(horizon.Account{}, nil).Once() + mChannelAccountStore.On("Get", ctx, mock.Anything, acc.PublicKey, 0). + Return(acc, nil).Once() + mChannelAccountStore.On("DeleteIfLockedUntil", ctx, acc.PublicKey, currLedgerNum+engine.IncrementForMaxLedgerBounds). + Return(nil).Once() + } + + mHorizonClient.On( + "SubmitTransactionWithOptions", + mock.Anything, + horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}, + ).Return(horizon.Transaction{}, nil).Times(currChannelAccountsCount - opts.NumChannelAccounts) + + err = cas.EnsureChannelAccountsCount(ctx, opts) + require.NoError(t, err) + + mChannelAccountStore.AssertExpectations(t) + mHorizonClient.AssertExpectations(t) + mLedgerNumberTracker.AssertExpectations(t) +} + +func Test_ChannelAccounts_ViewChannelAccounts_Success(t *testing.T) { + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: &horizonclient.MockClient{}, + } + + channelAccounts := []*store.ChannelAccount{ + { + PublicKey: "GDTQYQQSQ5AG6ZYERKU5VH3RBPEZ33U5HGYM6SPUY42QULOQIC2MRZ3N", + }, + { + PublicKey: "GDXSRISWI6ZVFVVOUU2DNKVHUYEJQZ63A37P6C5NGKXBROW5WW5W6HW3", + }, + { + PublicKey: "GAR7SZWK2GV23OGIQC2BBZUUDSVSMT3MUOY7NJLJ75W5OJ3KQUR7VAIV", + }, + } + + ctx := context.Background() + mChannelAccountStore.On("GetAll", ctx, mock.Anything, 0, 0).Return(channelAccounts, nil).Once() + getEntries := log.DefaultLogger.StartTest(log.InfoLevel) + + err := cas.ViewChannelAccounts(ctx) + require.NoError(t, err) + + entries := getEntries() + for i, entry := range entries[1:] { + assert.Equal( + t, + entry.Message, + fmt.Sprintf("Found account %s", channelAccounts[i].PublicKey), + ) + } + + mChannelAccountStore.AssertExpectations(t) +} + +func Test_ChannelAccounts_ViewChannelAccounts_LoadChannelAccountsError_Failure(t *testing.T) { + mChannelAccountStore := &storeMocks.MockChannelAccountStore{} + + cas := ChannelAccountsService{ + caStore: mChannelAccountStore, + horizonClient: &horizonclient.MockClient{}, + } + ctx := context.Background() + mChannelAccountStore.On("GetAll", ctx, mock.Anything, 0, 0). + Return(nil, errors.New("db error")).Once() + + err := cas.ViewChannelAccounts(ctx) + require.EqualError(t, err, "loading channel accounts from database in ViewChannelAccounts: db error") + + mChannelAccountStore.AssertExpectations(t) +} diff --git a/internal/transactionsubmission/services/mocks.go b/internal/transactionsubmission/services/mocks.go new file mode 100644 index 000000000..e7e34ad84 --- /dev/null +++ b/internal/transactionsubmission/services/mocks.go @@ -0,0 +1,43 @@ +package services + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type ChannelAccountsServiceMock struct { + mock.Mock +} + +func (cas *ChannelAccountsServiceMock) CreateChannelAccountsOnChain(ctx context.Context, opts ChannelAccountServiceOptions) error { + args := cas.Called(ctx, opts) + return args.Error(0) +} + +func (cas *ChannelAccountsServiceMock) VerifyChannelAccounts(ctx context.Context, opts ChannelAccountServiceOptions) error { + args := cas.Called(ctx) + return args.Error(0) +} + +func (cas *ChannelAccountsServiceMock) DeleteChannelAccounts(ctx context.Context) error { + args := cas.Called(ctx) + return args.Error(0) +} + +func (cas *ChannelAccountsServiceMock) DeleteChannelAccount(ctx context.Context, opts ChannelAccountServiceOptions) error { + args := cas.Called(ctx, opts) + return args.Error(0) +} + +func (cas *ChannelAccountsServiceMock) EnsureChannelAccountsCount(ctx context.Context, opts ChannelAccountServiceOptions) error { + args := cas.Called(ctx, opts) + return args.Error(0) +} + +func (cas *ChannelAccountsServiceMock) ViewChannelAccounts(ctx context.Context) error { + args := cas.Called(ctx) + return args.Error(0) +} + +var _ ChannelAccountsServiceInterface = (*ChannelAccountsServiceMock)(nil) diff --git a/internal/transactionsubmission/store/channel_account.go b/internal/transactionsubmission/store/channel_account.go new file mode 100644 index 000000000..f6a70817b --- /dev/null +++ b/internal/transactionsubmission/store/channel_account.go @@ -0,0 +1,371 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/lib/pq" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +type ChannelAccount struct { + PublicKey string `db:"public_key"` + PrivateKey string `db:"private_key"` // TODO: remove this from the model, since we now rely on a Signer interface. + UpdatedAt *time.Time `db:"updated_at"` + CreatedAt *time.Time `db:"created_at"` + LockedAt sql.NullTime `db:"locked_at"` + // LockedUntilLedgerNumber is the ledger number after which the lock expires. It should be synched with the + // expiration ledger bound of the transaction submitted by this Stellar channel account. + LockedUntilLedgerNumber sql.NullInt32 `db:"locked_until_ledger_number"` +} + +func (ca *ChannelAccount) IsLocked(currentLedgerNumber int32) bool { + return ca.LockedUntilLedgerNumber.Valid && currentLedgerNumber <= ca.LockedUntilLedgerNumber.Int32 +} + +type ChannelAccountModel struct { + DBConnectionPool db.DBConnectionPool +} + +func NewChannelAccountModel(dbConnectionPool db.DBConnectionPool) *ChannelAccountModel { + return &ChannelAccountModel{DBConnectionPool: dbConnectionPool} +} + +// Insert inserts a (publicKey, privateKey) pair to the database. +func (ca *ChannelAccountModel) Insert(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, privateKey string) error { + err := ca.BatchInsert(ctx, sqlExec, []*ChannelAccount{{PublicKey: publicKey, PrivateKey: privateKey}}) + if err != nil { + return fmt.Errorf("inserting channel account %q: %w", publicKey, err) + } + + return nil +} + +// BatchInsert inserts a a batch of (publicKey, privateKey) pairs into the database. +func (ca *ChannelAccountModel) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error { + if len(channelAccounts) == 0 { + return nil + } + + publicKeys := make([]string, len(channelAccounts)) + privateKeys := make([]string, len(channelAccounts)) + + for i, chAcc := range channelAccounts { + if chAcc.PublicKey == "" { + return fmt.Errorf("public key cannot be empty") + } + if chAcc.PrivateKey == "" { + return fmt.Errorf("private key cannot be empty") + } + + publicKeys[i] = chAcc.PublicKey + privateKeys[i] = chAcc.PrivateKey + } + + const q = ` + INSERT INTO + channel_accounts (public_key, private_key) + SELECT * + FROM UNNEST($1::text[], $2::text[]) + ` + + _, err := sqlExec.ExecContext(ctx, q, pq.Array(publicKeys), pq.Array(privateKeys)) + if err != nil { + return fmt.Errorf("inserting channel accounts: %w", err) + } + + return nil +} + +// InsertAndLock insert an account keypair into the database and locks it until some future ledger. +func (ca *ChannelAccountModel) InsertAndLock(ctx context.Context, publicKey string, privateKey string, currentLedger, nextLedgerLock int) error { + return db.RunInTransaction(ctx, ca.DBConnectionPool, nil, func(dbTx db.DBTransaction) error { + err := ca.Insert(ctx, dbTx, publicKey, privateKey) + if err != nil { + return fmt.Errorf("cannot insert account %s: %w", publicKey, err) + } + + _, err = ca.Lock(ctx, dbTx, publicKey, int32(currentLedger), int32(nextLedgerLock)) + if err != nil { + return fmt.Errorf("cannot lock account %s: %w", publicKey, err) + } + + return nil + }) +} + +// BatchInsertAndLock inserts a batch of account keypairs into the database and locks them until some future ledger. +func (ca *ChannelAccountModel) BatchInsertAndLock(ctx context.Context, channelAccounts []*ChannelAccount, currentLedger, nextLedgerLock int) error { + return db.RunInTransaction(ctx, ca.DBConnectionPool, nil, func(dbTx db.DBTransaction) error { + err := ca.BatchInsert(ctx, dbTx, channelAccounts) + if err != nil { + return fmt.Errorf("cannot insert batch insert %d accounts: %w", len(channelAccounts), err) + } + + for _, account := range channelAccounts { + _, err = ca.Lock(ctx, dbTx, account.PublicKey, int32(currentLedger), int32(nextLedgerLock)) + if err != nil { + return fmt.Errorf("cannot lock account %s: %w", account.PublicKey, err) + } + } + + return nil + }) +} + +// Get retrieves the channel account with the given public key from the database if account is not locked or `currentLedgerNumber` is +// ahead of the ledger number the account has been locked to. +func (ca *ChannelAccountModel) Get(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, currentLedgerNumber int) (*ChannelAccount, error) { + query := ` + SELECT + * + FROM + channel_accounts + WHERE + public_key = $1%s + FOR UPDATE SKIP LOCKED + ` + + if currentLedgerNumber > 0 { + query = fmt.Sprintf(query, "\nAND "+ca.queryFilterForLockedState(false, int32(currentLedgerNumber))) + } else if currentLedgerNumber == 0 { + // bypass locked until ledger check for read-only purposes such as retrieving the keypair for signing + query = fmt.Sprintf(query, "") + } else { + return nil, fmt.Errorf("invalid ledger number %d", currentLedgerNumber) + } + + var channelAccount ChannelAccount + err := sqlExec.GetContext(ctx, &channelAccount, query, publicKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("could not find channel account %q: %w", publicKey, ErrRecordNotFound) + } + return nil, fmt.Errorf("querying for channel account %q: %w", publicKey, err) + } + + return &channelAccount, nil +} + +// GetAndLock retrieves the channel account with the given public key from the database and locks the account until some future ledger. +func (ca *ChannelAccountModel) GetAndLock(ctx context.Context, publicKey string, currentLedger, nextLedgerLock int) (*ChannelAccount, error) { + channelAccount, err := ca.Get(ctx, ca.DBConnectionPool, publicKey, currentLedger) + if err != nil { + return nil, fmt.Errorf("cannot retrieve account %s: %w", publicKey, err) + } + + lockedAccount, err := ca.Lock(ctx, ca.DBConnectionPool, channelAccount.PublicKey, int32(currentLedger), int32(nextLedgerLock)) + if err != nil { + return nil, fmt.Errorf("cannot lock account %s: %w", channelAccount.PublicKey, err) + } + + return lockedAccount, nil +} + +// Count retrieves the current count of channel accounts in the database. +func (ca *ChannelAccountModel) Count(ctx context.Context) (int, error) { + query := ` + SELECT + COUNT(*) + FROM + channel_accounts + ` + + var count int + err := ca.DBConnectionPool.GetContext(ctx, &count, query) + if err != nil { + return 0, fmt.Errorf("counting channel accounts: %w", err) + } + + return count, nil +} + +// GetAll all channel accounts from the database, respecting the limit provided for accounts that are not locked or `currentLedgerNumber` is +// ahead of the ledger number each account has been locked to. +func (ca *ChannelAccountModel) GetAll(ctx context.Context, sqlExec db.SQLExecuter, currentLedgerNumber, limit int) ([]*ChannelAccount, error) { + baseQuery := ` + SELECT + * + FROM + channel_accounts%s + FOR UPDATE SKIP LOCKED + ` + + if currentLedgerNumber > 0 { + baseQuery = fmt.Sprintf(baseQuery, "\nWHERE"+ca.queryFilterForLockedState(false, int32(currentLedgerNumber))) + } else if currentLedgerNumber == 0 { + // bypass locked until ledger check for read-only purposes such as retrieving the keypair for signing + baseQuery = fmt.Sprintf(baseQuery, "") + } else { + return nil, fmt.Errorf("invalid ledger number %d", currentLedgerNumber) + } + + query, params := ca.newLoadChannelAccountsLimitFromDatabase(baseQuery, limit) + + var accounts []*ChannelAccount + err := sqlExec.SelectContext(ctx, &accounts, query, params...) + if err != nil { + return nil, fmt.Errorf("loading channel accounts from database: %w", err) + } + + return accounts, nil +} + +// GetAndLockAll retrieves all channel account that are not already locked from the database and locks them until some future ledger. +func (ca *ChannelAccountModel) GetAndLockAll(ctx context.Context, currentLedger, nextLedgerLock, limit int) ([]*ChannelAccount, error) { + channelAccounts, err := ca.GetAll(ctx, ca.DBConnectionPool, currentLedger, limit) + if err != nil { + return nil, fmt.Errorf("cannot retrieve accounts for locking: %w", err) + } + if len(channelAccounts) == 0 { + return nil, fmt.Errorf("no channel accounts available to retrieve") + } + + var updatedChannelAccounts []*ChannelAccount + for _, channelAccount := range channelAccounts { + lockedAccount, err := ca.Lock(ctx, ca.DBConnectionPool, channelAccount.PublicKey, int32(currentLedger), int32(nextLedgerLock)) + if err != nil { + return nil, fmt.Errorf("cannot lock account %s: %w", channelAccount.PublicKey, err) + } + + updatedChannelAccounts = append(updatedChannelAccounts, lockedAccount) + } + + return updatedChannelAccounts, nil +} + +// newLoadChannelAccountsLimitFromDatabase returns a query that limits the number of channel accounts retrieved if limit>0, +// or retrieves all channel accounts if limit=0. +func (ca *ChannelAccountModel) newLoadChannelAccountsLimitFromDatabase( + baseQuery string, limit int, +) (query string, params []interface{}) { + qb := data.NewQueryBuilder(baseQuery) + if limit > 0 { + qb.AddPagination(1, limit) + } + query, params = qb.Build() + return ca.DBConnectionPool.Rebind(query), params +} + +// Delete deletes a channel account with the provided publicKey from the database. +func (ca *ChannelAccountModel) Delete(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) error { + query := ` + DELETE + FROM + channel_accounts + WHERE + public_key = $1 + ` + + res, err := sqlExec.ExecContext(ctx, query, publicKey) + if err != nil { + return fmt.Errorf("deleting channel account %q: %w", publicKey, err) + } + + numRowsAffected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("getting number of rows affected: %w", err) + } + + if numRowsAffected == 0 { + return fmt.Errorf("could not find nor delete account %q: %w", publicKey, ErrRecordNotFound) + } else if numRowsAffected != 1 { + return fmt.Errorf("unexpected number of rows affected: %d when deleting channel account %s", numRowsAffected, publicKey) + } + + return nil +} + +// DeleteIfLockedUntil deletes a channel account with the provided publicKey from the database only if the provided +// `lockedUntilLedgerNumber` matches the value of the same field on the channel account. Also, if the account has not been +// locked previously, does not proceed with the deletion. +func (ca *ChannelAccountModel) DeleteIfLockedUntil(ctx context.Context, publicKey string, lockedUntilLedgerNumber int) error { + return db.RunInTransaction(ctx, ca.DBConnectionPool, nil, func(dbTx db.DBTransaction) error { + account, err := ca.Get(ctx, dbTx, publicKey, 0) + if err != nil { + return fmt.Errorf("cannot retrieve account %s: %w", publicKey, err) + } + + if !(account.LockedUntilLedgerNumber.Valid && account.LockedUntilLedgerNumber.Int32 == int32(lockedUntilLedgerNumber)) { + return fmt.Errorf("cannot delete account due to locked until ledger number mismatch or field being null") + } + + _, err = ca.Unlock(ctx, dbTx, account.PublicKey) + if err != nil { + return fmt.Errorf("cannot unlock account for deletion %s: %w", account.PublicKey, err) + } + + err = ca.Delete(ctx, dbTx, account.PublicKey) + if err != nil { + return fmt.Errorf("cannot delete account %s: %w", account.PublicKey, err) + } + + return nil + }) +} + +// queryFilterForLockedState returns a SQL query filter that can be used to filter channel accounts based on their +// locked state. +func (ca *ChannelAccountModel) queryFilterForLockedState(locked bool, ledgerNumber int32) string { + if locked { + return fmt.Sprintf("(locked_until_ledger_number >= %d)", ledgerNumber) + } + return fmt.Sprintf("(locked_until_ledger_number IS NULL OR locked_until_ledger_number < %d)", ledgerNumber) +} + +// Lock locks the channel account with the provided publicKey. It returns a ErrRecordNotFound error if you try to lock a +// channel account that is already locked. +func (ca *ChannelAccountModel) Lock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, currentLedger, nextLedgerLock int32) (*ChannelAccount, error) { + q := fmt.Sprintf(` + UPDATE + channel_accounts + SET + locked_at = NOW(), + locked_until_ledger_number = $1 + WHERE + public_key = $2 + AND %s + RETURNING * + `, ca.queryFilterForLockedState(false, currentLedger)) + var channelAccount ChannelAccount + err := sqlExec.GetContext(ctx, &channelAccount, q, nextLedgerLock, publicKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("locking channel account %q: %w", publicKey, err) + } + + return &channelAccount, nil +} + +// Unlock lifts the lock from the channel account with the provided publicKey. +func (ca *ChannelAccountModel) Unlock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*ChannelAccount, error) { + q := ` + UPDATE + channel_accounts + SET + locked_at = NULL, + locked_until_ledger_number = NULL + WHERE + public_key = $1 + RETURNING * + ` + var channelAccount ChannelAccount + err := sqlExec.GetContext(ctx, &channelAccount, q, publicKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("unlocking channel account %q: %w", publicKey, err) + } + + return &channelAccount, nil +} + +var _ ChannelAccountStore = &ChannelAccountModel{} diff --git a/internal/transactionsubmission/store/channel_account_test.go b/internal/transactionsubmission/store/channel_account_test.go new file mode 100644 index 000000000..e5e2e615e --- /dev/null +++ b/internal/transactionsubmission/store/channel_account_test.go @@ -0,0 +1,737 @@ +package store + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/stellar/go/keypair" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ChannelAccount_IsLocked(t *testing.T) { + const currentLedgerNumber = 10 + + testCases := []struct { + name string + lockedUntilLedgerNumber sql.NullInt32 + wantResult bool + }{ + { + name: "returns false if lockedUntilLedgerNumber is null", + lockedUntilLedgerNumber: sql.NullInt32{}, + wantResult: false, + }, + { + name: "returns false if lockedUntilLedgerNumber is lower than currentLedgerNumber", + lockedUntilLedgerNumber: sql.NullInt32{Int32: currentLedgerNumber - 1, Valid: true}, + wantResult: false, + }, + { + name: "returns true if lockedUntilLedgerNumber is equal to currentLedgerNumber", + lockedUntilLedgerNumber: sql.NullInt32{Int32: currentLedgerNumber, Valid: true}, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ca := &ChannelAccount{LockedUntilLedgerNumber: tc.lockedUntilLedgerNumber} + assert.Equal(t, tc.wantResult, ca.IsLocked(currentLedgerNumber)) + }) + } +} + +func Test_ChannelAccountModel_BatchInsert_GetAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + caModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + const currentLedger int32 = 1 + const nextLedgerLock int32 = 11 + + testCases := []struct { + name string + chAccounts []*ChannelAccount + wantErrContains string + queryAtLedger int + lockAccounts bool + }{ + { + name: "empty accounts won't return error and won't create any records", + }, + { + name: "returns error if a public key is empty", + chAccounts: []*ChannelAccount{ + { + PublicKey: "", + PrivateKey: "SAIXHVEDXDEO37PUD7SAJU2BPZGRP43EI3FOPHP4L7AP3LICY6AMIR6T", + }, + }, + wantErrContains: "public key cannot be empty", + }, + { + name: "returns error if a private key is empty", + chAccounts: []*ChannelAccount{ + { + PublicKey: "GCFZGYGGXEMPJNL52QX2DXG2X5ZHJ3XTEWAUBWXQE2PXX7V532AI4ALT", + PrivateKey: "", + }, + }, + wantErrContains: "private key cannot be empty", + }, + { + name: "πŸŽ‰ successfully insert one channel account", + chAccounts: []*ChannelAccount{ + { + PublicKey: "GCFZGYGGXEMPJNL52QX2DXG2X5ZHJ3XTEWAUBWXQE2PXX7V532AI4ALT", + PrivateKey: "SAIXHVEDXDEO37PUD7SAJU2BPZGRP43EI3FOPHP4L7AP3LICY6AMIR6T", + }, + }, + }, + { + name: "returns 0 channel accounts when querying at ledger number before accounts are unlocked", + chAccounts: []*ChannelAccount{ + { + PublicKey: "GCFZGYGGXEMPJNL52QX2DXG2X5ZHJ3XTEWAUBWXQE2PXX7V532AI4ALT", + PrivateKey: "SAIXHVEDXDEO37PUD7SAJU2BPZGRP43EI3FOPHP4L7AP3LICY6AMIR6T", + }, + { + PublicKey: "GAL3MHT7SWJXV33JHK2BENHUVUZLENMJFYOLJU4CLI3723MDSRJL5AJM", + PrivateKey: "SBHQLRTVR2HKLRE5UKKV2VIZIR7VHZQ6375KWOKU3E6H2AKE374VICXQ", + }, + }, + queryAtLedger: 5, + lockAccounts: true, + }, + { + name: "πŸŽ‰ successfully insert multiple channel accounts", + chAccounts: []*ChannelAccount{ + { + PublicKey: "GCFZGYGGXEMPJNL52QX2DXG2X5ZHJ3XTEWAUBWXQE2PXX7V532AI4ALT", + PrivateKey: "SAIXHVEDXDEO37PUD7SAJU2BPZGRP43EI3FOPHP4L7AP3LICY6AMIR6T", + }, + { + PublicKey: "GAL3MHT7SWJXV33JHK2BENHUVUZLENMJFYOLJU4CLI3723MDSRJL5AJM", + PrivateKey: "SBHQLRTVR2HKLRE5UKKV2VIZIR7VHZQ6375KWOKU3E6H2AKE374VICXQ", + }, + }, + }, + } + + type comparableChAccount struct { + PublicKey string + PrivateKey string + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + batchInsertErr := caModel.BatchInsert(ctx, caModel.DBConnectionPool, tc.chAccounts) + + if tc.lockAccounts { + for _, ca := range tc.chAccounts { + _, err = caModel.Lock(ctx, caModel.DBConnectionPool, ca.PublicKey, currentLedger, nextLedgerLock) + require.NoError(t, err) + } + } + + allChAccounts, getAllErr := caModel.GetAll(ctx, caModel.DBConnectionPool, tc.queryAtLedger, 0) + require.NoError(t, getAllErr) + + if tc.wantErrContains != "" { + require.Error(t, batchInsertErr) + assert.ErrorContains(t, batchInsertErr, tc.wantErrContains) + } else if tc.lockAccounts { + require.NoError(t, err) + assert.Len(t, allChAccounts, 0) + } else { + require.NoError(t, batchInsertErr) + assert.Equal(t, len(tc.chAccounts), len(allChAccounts)) + + // compare the accounts + var allChAccountsComparable []comparableChAccount + for _, chAccount := range allChAccounts { + allChAccountsComparable = append(allChAccountsComparable, comparableChAccount{ + PublicKey: chAccount.PublicKey, + PrivateKey: chAccount.PrivateKey, + }) + } + + var tcChAccountsComparable []comparableChAccount + for _, chAccount := range tc.chAccounts { + tcChAccountsComparable = append(tcChAccountsComparable, comparableChAccount{ + PublicKey: chAccount.PublicKey, + PrivateKey: chAccount.PrivateKey, + }) + } + + assert.ElementsMatch(t, tcChAccountsComparable, allChAccountsComparable) + } + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } +} + +func Test_ChannelAccountModel_Insert_Get(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + caModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + const currentLedger int32 = 1 + const nextLedgerLock int32 = 11 + + testCases := []struct { + name string + channelAccounts []*ChannelAccount + publicKeysToQuery []string + expectedErrorFormat string + lockAccounts bool + queryAtLedger int + errorType string + }{ + { + name: "can insert and get a channel account", + channelAccounts: []*ChannelAccount{ + { + PublicKey: "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + PrivateKey: "SBXVHYY2VXHTXGHSQ4VXC7LSUUECQY633CZTY5Q6JCYRP5KQC4WRWU25", + }, + { + PublicKey: "GCXLO7JS3X7H45ZQEJIA2NQPCAPGQW3TSYGPWUUSBXDXDGMZZFHEWZSU", + PrivateKey: "SCYDT4TJF43OAO3TYQQAWKEPOGJSBXZGW3WVQZGTOXVVCSM5TFTCJQRZ", + }, + }, + publicKeysToQuery: []string{ + "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + "GCXLO7JS3X7H45ZQEJIA2NQPCAPGQW3TSYGPWUUSBXDXDGMZZFHEWZSU", + }, + expectedErrorFormat: "", + }, + { + name: "can get channel account at valid ledger number", + channelAccounts: []*ChannelAccount{ + { + PublicKey: "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + PrivateKey: "SBXVHYY2VXHTXGHSQ4VXC7LSUUECQY633CZTY5Q6JCYRP5KQC4WRWU25", + }, + }, + publicKeysToQuery: []string{ + "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + }, + queryAtLedger: 12, + lockAccounts: true, + }, + { + name: "returns an error when trying to get a channel account that does not exist", + channelAccounts: []*ChannelAccount{}, + publicKeysToQuery: []string{ + "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + "GCXLO7JS3X7H45ZQEJIA2NQPCAPGQW3TSYGPWUUSBXDXDGMZZFHEWZSU", + }, + expectedErrorFormat: "could not find channel account %q: record not found", + errorType: "invalid acocunt", + }, + { + name: "returns an error when querying at invalid ledger number", + channelAccounts: []*ChannelAccount{ + { + PublicKey: "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + PrivateKey: "SBXVHYY2VXHTXGHSQ4VXC7LSUUECQY633CZTY5Q6JCYRP5KQC4WRWU25", + }, + }, + publicKeysToQuery: []string{ + "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + }, + queryAtLedger: -1, + expectedErrorFormat: "invalid ledger number %d", + errorType: "invalid ledger number", + }, + { + name: "returns an error when querying for locked channel account", + channelAccounts: []*ChannelAccount{ + { + PublicKey: "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + PrivateKey: "SBXVHYY2VXHTXGHSQ4VXC7LSUUECQY633CZTY5Q6JCYRP5KQC4WRWU25", + }, + }, + publicKeysToQuery: []string{ + "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + }, + queryAtLedger: 5, + expectedErrorFormat: "could not find channel account %q: record not found", + errorType: "locked channel account", + lockAccounts: true, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + for _, ca := range test.channelAccounts { + err = caModel.Insert(ctx, caModel.DBConnectionPool, ca.PublicKey, ca.PrivateKey) + require.NoError(t, err) + } + + for _, pubKey := range test.publicKeysToQuery { + // lock accounts if queryAtLedger specified + if test.lockAccounts { + _, err = caModel.Lock(ctx, caModel.DBConnectionPool, pubKey, currentLedger, nextLedgerLock) + require.NoError(t, err) + } + + cam, err := caModel.Get(ctx, caModel.DBConnectionPool, pubKey, test.queryAtLedger) + if test.expectedErrorFormat != "" { + require.Error(t, err) + switch test.errorType { + case "invalid account", "locked channel account": + assert.EqualError(t, err, fmt.Sprintf(test.expectedErrorFormat, pubKey)) + case "invalid ledger number": + assert.EqualError(t, err, fmt.Sprintf(test.expectedErrorFormat, test.queryAtLedger)) + } + + } else { + require.NoError(t, err) + assert.Equal(t, pubKey, cam.PublicKey) + } + } + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } +} + +func Test_ChannelAccountModel_Insert_Count(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + caModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + testCases := []struct { + numChannelAccounts int + }{ + {numChannelAccounts: 0}, + {numChannelAccounts: 1}, + {numChannelAccounts: 10}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%d channel cccount(s)", tc.numChannelAccounts), func(t *testing.T) { + for range make([]interface{}, tc.numChannelAccounts) { + kp, err := keypair.Random() + require.NoError(t, err) + err = caModel.Insert(ctx, caModel.DBConnectionPool, kp.Address(), kp.Seed()) + require.NoError(t, err) + } + + count, err := caModel.Count(ctx) + require.NoError(t, err) + assert.Equal(t, tc.numChannelAccounts, count) + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } +} + +func Test_ChannelAccountModel_Insert_Delete(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + caModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + ca := &ChannelAccount{ + PublicKey: "GDLYOWHAC2U4I52OXDEWMEAVNR6WLML3LIG32QOOLKWPCC233OBSKVU5", + PrivateKey: "SBXVHYY2VXHTXGHSQ4VXC7LSUUECQY633CZTY5Q6JCYRP5KQC4WRWU25", + } + + testCases := []struct { + name string + channelAccountToAdd *ChannelAccount + channelAccountToDelete *ChannelAccount + expectedErrorFormat string + }{ + { + name: "add and delete channel account", + channelAccountToAdd: ca, + channelAccountToDelete: ca, + }, + { + name: "returns an error when trying to delete a channel account that does not exist", + channelAccountToAdd: ca, + channelAccountToDelete: &ChannelAccount{ + PublicKey: "GCXLO7JS3X7H45ZQEJIA2NQPCAPGQW3TSYGPWUUSBXDXDGMZZFHEWZSU", + PrivateKey: "SCYDT4TJF43OAO3TYQQAWKEPOGJSBXZGW3WVQZGTOXVVCSM5TFTCJQRZ", + }, + expectedErrorFormat: "could not find nor delete account %q: record not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = caModel.Insert(ctx, caModel.DBConnectionPool, tc.channelAccountToAdd.PublicKey, tc.channelAccountToAdd.PrivateKey) + require.NoError(t, err) + + err = caModel.Delete(ctx, caModel.DBConnectionPool, tc.channelAccountToDelete.PublicKey) + if tc.expectedErrorFormat != "" { + require.Error(t, err) + assert.EqualError(t, fmt.Errorf(tc.expectedErrorFormat, tc.channelAccountToDelete.PublicKey), err.Error()) + } else { + require.NoError(t, err) + } + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } +} + +func Test_ChannelAccountModel_queryFilterForLockedState(t *testing.T) { + chAccModel := &ChannelAccountModel{} + + testCases := []struct { + name string + locked bool + ledgerNumber int32 + wantFilter string + }{ + { + name: "locked to ledgerNumber=10", + locked: true, + ledgerNumber: 10, + wantFilter: "(locked_until_ledger_number >= 10)", + }, + { + name: "unlocked or expired on ledgerNumber=20", + locked: false, + ledgerNumber: 20, + wantFilter: "(locked_until_ledger_number IS NULL OR locked_until_ledger_number < 20)", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotFilter := chAccModel.queryFilterForLockedState(tc.locked, tc.ledgerNumber) + assert.Equal(t, tc.wantFilter, gotFilter) + }) + } +} + +func Test_ChannelAccountModel_Lock(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + const currentLedger int32 = 10 + const nextLedgerLock int32 = 20 + + testCases := []struct { + name string + initialLockedAt sql.NullTime + initialLockedUntilLedger sql.NullInt32 + expectedErrContains string + }{ + { + name: "πŸŽ‰ successfully locks channel account without any previous lock", + }, + { + name: "πŸŽ‰ successfully locks channel account with lock expired", + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger - 1, Valid: true}, + }, + { + name: "🚧 cannot be locked again if still locked", + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger, Valid: true}, + expectedErrContains: ErrRecordNotFound.Error(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + channelAccount := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1)[0] + q := `UPDATE channel_accounts SET locked_at = $1, locked_until_ledger_number = $2 WHERE public_key = $3` + _, err := dbConnectionPool.ExecContext(ctx, q, tc.initialLockedAt, tc.initialLockedUntilLedger, channelAccount.PublicKey) + require.NoError(t, err) + + channelAccount, err = chAccModel.Lock(ctx, dbConnectionPool, channelAccount.PublicKey, currentLedger, nextLedgerLock) + + if tc.expectedErrContains == "" { + require.NoError(t, err) + channelAccount, err = chAccModel.Get(ctx, chAccModel.DBConnectionPool, channelAccount.PublicKey, 0) + require.NoError(t, err) + assert.True(t, channelAccount.LockedAt.Valid) + assert.True(t, channelAccount.LockedUntilLedgerNumber.Valid) + assert.Equal(t, nextLedgerLock, channelAccount.LockedUntilLedgerNumber.Int32) + + var channelAccountRefreshed *ChannelAccount + channelAccountRefreshed, err = chAccModel.Get(ctx, chAccModel.DBConnectionPool, channelAccount.PublicKey, 0) + require.NoError(t, err) + require.Equal(t, *channelAccountRefreshed, *channelAccount) + } else { + require.Error(t, err) + assert.ErrorContains(t, err, tc.expectedErrContains) + } + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } +} + +func Test_ChannelAccountModel_Unlock(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + const currentLedger int32 = 10 + + testCases := []struct { + name string + initialLockedAt sql.NullTime + initialLockedUntilLedger sql.NullInt32 + }{ + { + name: "πŸŽ‰ successfully unlocks channel account that were not locked", + }, + { + name: "πŸŽ‰ successfully unlocks channel account whose lock was expired", + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger - 1, Valid: true}, + }, + { + name: "πŸŽ‰ successfully unlocks locked channel account", + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger, Valid: true}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + channelAccount := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1)[0] + q := `UPDATE channel_accounts SET locked_at = $1, locked_until_ledger_number = $2 WHERE public_key = $3` + _, err := dbConnectionPool.ExecContext(ctx, q, tc.initialLockedAt, tc.initialLockedUntilLedger, channelAccount.PublicKey) + require.NoError(t, err) + + channelAccount, err = chAccModel.Unlock(ctx, dbConnectionPool, channelAccount.PublicKey) + require.NoError(t, err) + assert.False(t, channelAccount.LockedAt.Valid) + assert.False(t, channelAccount.LockedUntilLedgerNumber.Valid) + + channelAccountRefreshed, err := chAccModel.Get(ctx, chAccModel.DBConnectionPool, channelAccount.PublicKey, 0) + require.NoError(t, err) + require.Equal(t, *channelAccountRefreshed, *channelAccount) + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + }) + } +} + +func Test_ChannelAccountModel_Lock_Unlock(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + const currentLedger int32 = 10 + const nextLedgerLock int32 = 20 + + // On creation, channel account is unlocked + channelAccount := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1)[0] + assert.False(t, channelAccount.IsLocked(currentLedger)) + + count := 3 + for range make(sql.RawBytes, count) { + // Lock channel account + channelAccount, err = chAccModel.Lock(ctx, dbConnectionPool, channelAccount.PublicKey, currentLedger, nextLedgerLock) + require.NoError(t, err) + assert.True(t, channelAccount.IsLocked(currentLedger)) + + channelAccountRefreshed, err := chAccModel.Get(ctx, chAccModel.DBConnectionPool, channelAccount.PublicKey, 0) + require.NoError(t, err) + require.Equal(t, *channelAccountRefreshed, *channelAccount) + + // Unlock channel account + channelAccount, err = chAccModel.Unlock(ctx, dbConnectionPool, channelAccount.PublicKey) + require.NoError(t, err) + assert.False(t, channelAccount.IsLocked(currentLedger)) + + channelAccountRefreshed, err = chAccModel.Get(ctx, chAccModel.DBConnectionPool, channelAccount.PublicKey, 0) + require.NoError(t, err) + require.Equal(t, *channelAccountRefreshed, *channelAccount) + + count-- + } + + assert.Equal(t, 0, count) + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) +} + +func Test_ChannelAccountModel_DeleteIfLockedUntil(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + lockedToLedger := 100 + testCases := []struct { + name string + accountLockedUntilLedgerNumber int + deleteAtLedgerNumber int + expectedErrContains string + }{ + { + name: "returns error if delete at ledger number different from locked until ledger number", + accountLockedUntilLedgerNumber: lockedToLedger, + deleteAtLedgerNumber: lockedToLedger + 1, + expectedErrContains: "cannot delete account due to locked until ledger number mismatch or field being null", + }, + { + name: "returns error if account not locked to ledger", + deleteAtLedgerNumber: lockedToLedger, + expectedErrContains: "cannot delete account due to locked until ledger number mismatch or field being null", + }, + { + name: "successfully delete at ledger number", + accountLockedUntilLedgerNumber: lockedToLedger, + deleteAtLedgerNumber: lockedToLedger, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + channelAccount := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1)[0] + if tc.accountLockedUntilLedgerNumber != 0 { + _, lockErr := chAccModel.Lock( + ctx, + chAccModel.DBConnectionPool, + channelAccount.PublicKey, + int32(tc.accountLockedUntilLedgerNumber), + int32(tc.accountLockedUntilLedgerNumber), + ) + require.NoError(t, lockErr) + } + + err = chAccModel.DeleteIfLockedUntil(ctx, channelAccount.PublicKey, tc.deleteAtLedgerNumber) + if tc.expectedErrContains != "" { + require.ErrorContains(t, err, tc.expectedErrContains) + } else { + require.NoError(t, err) + } + + DeleteAllFromChannelAccounts(t, ctx, chAccModel.DBConnectionPool) + }) + } +} + +func Test_ChannelAccountModel_GetAndLockAll(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + t.Run("try to lock all accounts that have already been locked", func(t *testing.T) { + currLedgerNumber := int32(1) + channelAccounts := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 3) + for _, account := range channelAccounts { + _, err := chAccModel.Lock(ctx, chAccModel.DBConnectionPool, account.PublicKey, currLedgerNumber, currLedgerNumber+10) + require.NoError(t, err) + } + + _, err := chAccModel.GetAndLockAll(ctx, int(currLedgerNumber), int(currLedgerNumber+5), 0) + require.EqualError(t, err, "no channel accounts available to retrieve") + + DeleteAllFromChannelAccounts(t, ctx, chAccModel.DBConnectionPool) + }) + + t.Run("get and lock all available accounts", func(t *testing.T) { + currLedgerNumber := int32(1) + lockToLedgerNumber := currLedgerNumber + 10 + CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 3) + + updatedChannelAccounts, err := chAccModel.GetAndLockAll(ctx, int(currLedgerNumber), int(lockToLedgerNumber), 0) + require.NoError(t, err) + for _, account := range updatedChannelAccounts { + assert.Equal(t, account.LockedUntilLedgerNumber.Int32, int32(lockToLedgerNumber)) + } + + DeleteAllFromChannelAccounts(t, ctx, chAccModel.DBConnectionPool) + }) +} + +func Test_ChannelAccountModel_GetAndLock(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + chAccModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + t.Run("try to lock an account that has already been locked", func(t *testing.T) { + currLedgerNumber := int32(1) + channelAccount := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1)[0] + _, err := chAccModel.Lock(ctx, chAccModel.DBConnectionPool, channelAccount.PublicKey, currLedgerNumber, currLedgerNumber+10) + require.NoError(t, err) + + _, err = chAccModel.GetAndLock(ctx, channelAccount.PublicKey, int(currLedgerNumber), int(currLedgerNumber+5)) + require.ErrorContains(t, err, fmt.Sprintf("cannot retrieve account %s", channelAccount.PublicKey)) + + DeleteAllFromChannelAccounts(t, ctx, chAccModel.DBConnectionPool) + }) + + t.Run("get and lock an available account", func(t *testing.T) { + currLedgerNumber := int32(1) + lockToLedgerNumber := currLedgerNumber + 10 + channelAccount := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1)[0] + + updatedAccount, err := chAccModel.GetAndLock(ctx, channelAccount.PublicKey, int(currLedgerNumber), int(lockToLedgerNumber)) + require.NoError(t, err) + assert.Equal(t, updatedAccount.LockedUntilLedgerNumber.Int32, int32(lockToLedgerNumber)) + + DeleteAllFromChannelAccounts(t, ctx, chAccModel.DBConnectionPool) + }) +} diff --git a/internal/transactionsubmission/store/channel_transaction_bundle.go b/internal/transactionsubmission/store/channel_transaction_bundle.go new file mode 100644 index 000000000..fbb439f86 --- /dev/null +++ b/internal/transactionsubmission/store/channel_transaction_bundle.go @@ -0,0 +1,132 @@ +package store + +import ( + "context" + "fmt" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +var ErrInsuficientChannelAccounts = fmt.Errorf("there are no channel accounts available to process transactions") + +// ChannelTransactionBundle is an abstraction that aggregates a bundle of a ChannelAccount and a Transaction. It is used +// to prepare the resources for the workers, locking both the Transaction (the job) and the ChannelAccount (the +// resource), and then updating the lock according with the parameters provided. +type ChannelTransactionBundle struct { + // ChannelAccount is the resource needed to process the Transaction. + ChannelAccount ChannelAccount `db:"channel_account"` + // Transaction is the job that would be handled by the worker. + Transaction Transaction `db:"transaction"` + // LockedUntilLedgerNumber is the ledger number until which both the transaction and channel account are locked. + LockedUntilLedgerNumber int `db:"locked_until_ledger_number"` +} + +type ChannelTransactionBundleModel struct { + dbConnectionPool db.DBConnectionPool + channelAccountModel *ChannelAccountModel + transactionModel *TransactionModel +} + +func NewChannelTransactionBundleModel(dbConnectionPool db.DBConnectionPool) (*ChannelTransactionBundleModel, error) { + if dbConnectionPool == nil { + return nil, fmt.Errorf("dbConnectionPool cannot be nil") + } + + return &ChannelTransactionBundleModel{ + dbConnectionPool: dbConnectionPool, + channelAccountModel: &ChannelAccountModel{DBConnectionPool: dbConnectionPool}, + transactionModel: NewTransactionModel(dbConnectionPool), + }, nil +} + +// LoadAndLockTuples loads a slice of ChannelTransactionBundle from the database, and locks them until the given ledger +// number, up to the amount of transactions specified by the {limit} parameter. It returns the +// ErrInsuficientChannelAccounts error if there are transactions to process but no channel accounts available. +func (m *ChannelTransactionBundleModel) LoadAndLockTuples(ctx context.Context, currentLedgerNumber, lockToLedgerNumber, limit int) ([]*ChannelTransactionBundle, error) { + if limit < 1 { + return nil, fmt.Errorf("limit must be greater than 0") + } + + if lockToLedgerNumber <= currentLedgerNumber { + return nil, fmt.Errorf("lockToLedgerNumber must be greater than currentLedgerNumber") + } + + return db.RunInTransactionWithResult(ctx, m.dbConnectionPool, nil, func(dbTx db.DBTransaction) ([]*ChannelTransactionBundle, error) { + // STEP 1: get transactions available to be processed: + q := fmt.Sprintf(` + SELECT + * + FROM + submitter_transactions + WHERE + %s + AND synced_at IS NULL + AND status = ANY($1) + ORDER BY + updated_at ASC + LIMIT $2 + FOR UPDATE SKIP LOCKED + `, m.transactionModel.queryFilterForLockedState(false, int32(currentLedgerNumber)), + ) + var unlockedTransactions []Transaction + allowedTxStatuses := []TransactionStatus{TransactionStatusPending, TransactionStatusProcessing} + err := dbTx.SelectContext(ctx, &unlockedTransactions, q, pq.Array(allowedTxStatuses), limit) + if err != nil { + return nil, fmt.Errorf("fetching unlocked transactions: %w", err) + } + if len(unlockedTransactions) == 0 { + return nil, nil + } + + // STEP 2: get channel accounts available to process the transactions: + q = fmt.Sprintf(` + SELECT + * + FROM + channel_accounts + WHERE + %s + ORDER BY + updated_at ASC + LIMIT $1 + FOR UPDATE SKIP LOCKED + `, m.channelAccountModel.queryFilterForLockedState(false, int32(currentLedgerNumber)), + ) + var unlockedChannelAccounts []ChannelAccount + err = dbTx.SelectContext(ctx, &unlockedChannelAccounts, q, len(unlockedTransactions)) + if err != nil { + return nil, fmt.Errorf("calculating amount ov available channel accounts: %w", err) + } + if len(unlockedChannelAccounts) == 0 { + return nil, ErrInsuficientChannelAccounts + } + + // STEP 3: lock channel accounts and transactions, and build the bundle slice: + bundleLen := len(unlockedChannelAccounts) + bundles := make([]*ChannelTransactionBundle, bundleLen) + for i := 0; i < bundleLen; i++ { + chAcc := &unlockedChannelAccounts[i] + var lockedChAcc *ChannelAccount + lockedChAcc, err = m.channelAccountModel.Lock(ctx, dbTx, chAcc.PublicKey, int32(currentLedgerNumber), int32(lockToLedgerNumber)) + if err != nil { + return nil, fmt.Errorf("locking channel account %q: %w", chAcc.PublicKey, err) + } + + tx := &unlockedTransactions[i] + var lockedTx *Transaction + lockedTx, err = m.transactionModel.Lock(ctx, dbTx, tx.ID, int32(currentLedgerNumber), int32(lockToLedgerNumber)) + if err != nil { + return nil, fmt.Errorf("locking transaction %q: %w", tx.ID, err) + } + + bundles[i] = &ChannelTransactionBundle{ + ChannelAccount: *lockedChAcc, + Transaction: *lockedTx, + LockedUntilLedgerNumber: lockToLedgerNumber, + } + } + + return bundles, nil + }) +} diff --git a/internal/transactionsubmission/store/channel_transaction_bundle_test.go b/internal/transactionsubmission/store/channel_transaction_bundle_test.go new file mode 100644 index 000000000..baa71247d --- /dev/null +++ b/internal/transactionsubmission/store/channel_transaction_bundle_test.go @@ -0,0 +1,220 @@ +package store + +import ( + "context" + "fmt" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + sdpUtils "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/require" +) + +func Test_NewChannelTransactionBundleModel(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + testCases := []struct { + name string + dbConnection db.DBConnectionPool + expectedError error + expectedModel *ChannelTransactionBundleModel + }{ + { + name: "returns an error if dbConnectionPool is nil", + dbConnection: nil, + expectedError: fmt.Errorf("dbConnectionPool cannot be nil"), + expectedModel: nil, + }, + { + name: "πŸŽ‰ successfully returns a model if dbConnectionPool is not nil", + dbConnection: dbConnectionPool, + expectedError: nil, + expectedModel: &ChannelTransactionBundleModel{ + dbConnectionPool: dbConnectionPool, + channelAccountModel: &ChannelAccountModel{DBConnectionPool: dbConnectionPool}, + transactionModel: NewTransactionModel(dbConnectionPool), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualModel, actualError := NewChannelTransactionBundleModel(tc.dbConnection) + require.Equal(t, tc.expectedError, actualError) + require.Equal(t, tc.expectedModel, actualModel) + }) + } +} + +func Test_ChannelTransactionBundleModel_LoadAndLockTuples(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + const currentLedgerNumber = 100 + + chAccTupleModel, err := NewChannelTransactionBundleModel(dbConnectionPool) + require.NoError(t, err) + txModel := NewTransactionModel(dbConnectionPool) + chAccModel := ChannelAccountModel{dbConnectionPool} + + testCases := []struct { + name string + limit int + lockToLedgerNumber int + numberOfChannelAccountsLocked int + numberOfChannelAccountsUnlocked int + numberOfTransactionsLocked int + numberOfTransactionsUnlocked int + expectedError error + }{ + { + name: "returns an error if limit<1", + limit: 0, + expectedError: fmt.Errorf("limit must be greater than 0"), + }, + { + name: "returns an error if lockToLedgerNumber<=currentLedgerNumber", + limit: 100, + lockToLedgerNumber: currentLedgerNumber, + expectedError: fmt.Errorf("lockToLedgerNumber must be greater than currentLedgerNumber"), + }, + { + name: "returns nil len(transactions) == 0", + limit: 100, + lockToLedgerNumber: currentLedgerNumber + 1, + }, + { + name: "returns nil if len(unlockedTransactions) == 0", + limit: 100, + lockToLedgerNumber: currentLedgerNumber + 1, + numberOfTransactionsLocked: 10, + }, + { + name: "returns nil if len(unlockedTransactions) == 0 && len(unlockedChannelAccounts) > 0", + limit: 100, + lockToLedgerNumber: currentLedgerNumber + 1, + numberOfTransactionsLocked: 10, + numberOfChannelAccountsUnlocked: 10, + }, + { + name: "returns an error if len(unlockedTransactions) > 0 && len(unlockedChannelAccounts) == 0", + limit: 100, + lockToLedgerNumber: currentLedgerNumber + 1, + numberOfTransactionsUnlocked: 10, + expectedError: fmt.Errorf("running atomic function in RunInTransactionWithResult: %w", ErrInsuficientChannelAccounts), + }, + { + name: "πŸŽ‰ successfully returns chTxBundles if limit == len(unlockedTransactions) == len(unlockedChannelAccounts) > 0", + limit: 10, + lockToLedgerNumber: currentLedgerNumber + 10, + numberOfTransactionsUnlocked: 10, + numberOfChannelAccountsUnlocked: 10, + }, + { + name: "πŸŽ‰ successfully returns chTxBundles if limit < len(unlockedTransactions) == len(unlockedChannelAccounts) > 0", + limit: 5, + lockToLedgerNumber: currentLedgerNumber + 10, + numberOfTransactionsUnlocked: 10, + numberOfChannelAccountsUnlocked: 10, + }, + { + name: "πŸŽ‰ successfully returns chTxBundles if limit > len(unlockedTransactions) == len(unlockedChannelAccounts) > 0", + limit: 100, + lockToLedgerNumber: currentLedgerNumber + 10, + numberOfTransactionsUnlocked: 10, + numberOfChannelAccountsUnlocked: 10, + }, + { + name: "πŸŽ‰ successfully returns chTxBundles if limit > len(unlockedTransactions) > len(unlockedChannelAccounts) > 0", + limit: 100, + lockToLedgerNumber: currentLedgerNumber + 10, + numberOfTransactionsUnlocked: 20, + numberOfChannelAccountsUnlocked: 10, + }, + { + name: "πŸŽ‰ successfully returns chTxBundles if limit > len(unlockedChannelAccounts) > len(unlockedTransactions) > 0", + limit: 100, + lockToLedgerNumber: currentLedgerNumber + 10, + numberOfTransactionsUnlocked: 10, + numberOfChannelAccountsUnlocked: 20, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // ChannelAccounts(LOCKED) + lockedChAccounts := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, tc.numberOfChannelAccountsLocked) + for _, chAcc := range lockedChAccounts { + _, err = chAccModel.Lock(ctx, dbConnectionPool, chAcc.PublicKey, int32(currentLedgerNumber*2), int32(tc.lockToLedgerNumber)) + require.NoError(t, err) + } + + // ChannelAccounts(UNLOCKED) + unlockedChAccounts := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, tc.numberOfChannelAccountsUnlocked) + + // Transactions(LOCKED) + lockedTransactions := CreateTransactionFixtures(t, ctx, dbConnectionPool, tc.numberOfTransactionsLocked, "USDC", "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", "", TransactionStatusPending, 1) + for _, tx := range lockedTransactions { + _, err = txModel.Lock(ctx, dbConnectionPool, tx.ID, int32(currentLedgerNumber*2), int32(tc.lockToLedgerNumber)) + require.NoError(t, err) + } + + // Transactions(UNLOCKED) + unlockedTransactions := CreateTransactionFixtures(t, ctx, dbConnectionPool, tc.numberOfTransactionsUnlocked, "USDC", "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", "", TransactionStatusPending, 1) + + chTxBundles, err := chAccTupleModel.LoadAndLockTuples(ctx, currentLedgerNumber, tc.lockToLedgerNumber, tc.limit) + if tc.expectedError != nil { + require.Error(t, err) + require.Equal(t, tc.expectedError, err) + require.Empty(t, chTxBundles) + } else { + require.NoError(t, err) + minLength := tc.limit + if tc.numberOfChannelAccountsUnlocked < minLength { + minLength = tc.numberOfChannelAccountsUnlocked + } + if tc.numberOfTransactionsUnlocked < minLength { + minLength = tc.numberOfTransactionsUnlocked + } + require.Len(t, chTxBundles, minLength) + + if len(chTxBundles) == 0 { + return + } + + initiallyUnlockedChAccIDs := sdpUtils.MapSlice(unlockedChAccounts, func(chAcc *ChannelAccount) string { return chAcc.PublicKey }) + gotChAccIDs := sdpUtils.MapSlice(chTxBundles, func(chTxBundle *ChannelTransactionBundle) string { return chTxBundle.ChannelAccount.PublicKey }) + require.Subset(t, initiallyUnlockedChAccIDs, gotChAccIDs) + + initiallyUnlockedTxIDs := sdpUtils.MapSlice(unlockedTransactions, func(tx *Transaction) string { return tx.ID }) + gotTxIDs := sdpUtils.MapSlice(chTxBundles, func(chTxBundle *ChannelTransactionBundle) string { return chTxBundle.Transaction.ID }) + require.Subset(t, initiallyUnlockedTxIDs, gotTxIDs) + + // verify if the channel accounts are properly locked in the DB + var count int + q := fmt.Sprintf(`SELECT COUNT(*) FROM channel_accounts WHERE %s`, chAccModel.queryFilterForLockedState(true, currentLedgerNumber)) + err = dbConnectionPool.GetContext(ctx, &count, q) + require.NoError(t, err) + require.Equal(t, tc.numberOfChannelAccountsLocked+len(chTxBundles), count) + + // verify if the transactions are properly locked in the DB + q = fmt.Sprintf(`SELECT COUNT(*) FROM submitter_transactions WHERE %s`, txModel.queryFilterForLockedState(true, currentLedgerNumber)) + err = dbConnectionPool.GetContext(ctx, &count, q) + require.NoError(t, err) + require.Equal(t, tc.numberOfTransactionsLocked+len(chTxBundles), count) + } + + DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + }) + } +} diff --git a/internal/transactionsubmission/store/fixtures.go b/internal/transactionsubmission/store/fixtures.go new file mode 100644 index 000000000..310516faf --- /dev/null +++ b/internal/transactionsubmission/store/fixtures.go @@ -0,0 +1,102 @@ +package store + +import ( + "context" + "crypto/rand" + "math/big" + "testing" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + "github.com/stellar/go/keypair" + "github.com/stretchr/testify/require" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +// CreateTransactionFixtures creates count number submitter transactions +func CreateTransactionFixtures(t *testing.T, + ctx context.Context, + sqlExec db.SQLExecuter, + count int, + code, issuer, destination string, + status TransactionStatus, + amount float64, +) []*Transaction { + var txs []*Transaction + for i := 0; i < count; i++ { + tx := CreateTransactionFixture(t, ctx, sqlExec, uuid.NewString(), code, issuer, destination, status, amount) + txs = append(txs, tx) + } + + return txs +} + +// CreateTransactionFixture creates a submitter transaction in the database +func CreateTransactionFixture( + t *testing.T, + ctx context.Context, + sqlExec db.SQLExecuter, + externalID, assetCode, assetIssuer, destinationAddress string, + status TransactionStatus, + amount float64, +) *Transaction { + if assetIssuer == "" { + assetIssuer = keypair.MustRandom().Address() + } + + if destinationAddress == "" { + destinationAddress = keypair.MustRandom().Address() + } + + completedAt := pq.NullTime{} + if status == TransactionStatusSuccess || status == TransactionStatusError { + timeElapsed, _ := rand.Int(rand.Reader, big.NewInt(time.Now().Unix())) + randomCompletedAt := time.Unix(timeElapsed.Int64(), 0) + completedAt = pq.NullTime{Time: randomCompletedAt, Valid: true} + } + + const query = ` + INSERT INTO submitter_transactions + (external_id, status, asset_code, asset_issuer, amount, destination, completed_at, started_at) + VALUES + ($1, $2, $3, $4, $5, $6, $7, NOW()) + RETURNING + * + ` + + tx := Transaction{} + err := sqlExec.GetContext(ctx, &tx, query, externalID, string(status), assetCode, assetIssuer, amount, destinationAddress, completedAt) + require.NoError(t, err) + + return &tx +} + +// DeleteAllTransactionFixtures deletes all submitter transactions in the database +func DeleteAllTransactionFixtures(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + const query = "DELETE FROM submitter_transactions" + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} + +// CreateChannelAccountFixtures craetes count number of channel accounts +func CreateChannelAccountFixtures(t *testing.T, ctx context.Context, dbConnectionPool db.DBConnectionPool, count int) []*ChannelAccount { + caModel := ChannelAccountModel{DBConnectionPool: dbConnectionPool} + for i := 0; i < count; i++ { + generatedKeypair := keypair.MustRandom() + err := caModel.Insert(ctx, dbConnectionPool, generatedKeypair.Address(), generatedKeypair.Seed()) + require.NoError(t, err) + } + + channelAccounts, err := caModel.GetAll(ctx, dbConnectionPool, 0, count) + require.NoError(t, err) + + return channelAccounts +} + +func DeleteAllFromChannelAccounts(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter) { + query := `DELETE FROM channel_accounts` + _, err := sqlExec.ExecContext(ctx, query) + require.NoError(t, err) +} diff --git a/internal/transactionsubmission/store/fixtures_test.go b/internal/transactionsubmission/store/fixtures_test.go new file mode 100644 index 000000000..983b660b3 --- /dev/null +++ b/internal/transactionsubmission/store/fixtures_test.go @@ -0,0 +1,120 @@ +package store + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Fixtures_CreateTransactionFixture(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + tx := Transaction{ + AssetCode: "USDC", + AssetIssuer: "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", + Amount: 1, + } + + t.Run("create transaction with pending status", func(t *testing.T) { + tx.ExternalID = uuid.NewString() + createdTx := CreateTransactionFixture( + t, + ctx, + dbConnectionPool, + tx.ExternalID, tx.AssetCode, + tx.AssetIssuer, tx.Destination, + TransactionStatusPending, tx.Amount, + ) + assert.Equal(t, tx.AssetCode, createdTx.AssetCode) + assert.Equal(t, tx.AssetIssuer, createdTx.AssetIssuer) + assert.Equal(t, tx.ExternalID, createdTx.ExternalID) + assert.Equal(t, tx.Amount, createdTx.Amount) + assert.Empty(t, createdTx.CompletedAt) + }) + + t.Run("create transaction with successful status", func(t *testing.T) { + tx.ExternalID = uuid.NewString() + createdTx := CreateTransactionFixture( + t, + ctx, + dbConnectionPool, + tx.ExternalID, tx.AssetCode, + tx.AssetIssuer, tx.Destination, + TransactionStatusSuccess, tx.Amount, + ) + assert.Equal(t, tx.AssetCode, createdTx.AssetCode) + assert.Equal(t, tx.AssetIssuer, createdTx.AssetIssuer) + assert.Equal(t, tx.ExternalID, createdTx.ExternalID) + assert.Equal(t, tx.Amount, createdTx.Amount) + assert.False(t, createdTx.CompletedAt.IsZero()) + }) +} + +func Test_Fixtures_CreateAndDeleteAllTransactionFixtures(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + tx := Transaction{ + ExternalID: "external-id-1", + AssetCode: "USDC", + AssetIssuer: "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", + Amount: 1, + } + + t.Run("create and delete transactions", func(t *testing.T) { + txCount := 5 + createdTxs := CreateTransactionFixtures( + t, + ctx, + dbConnectionPool, + txCount, tx.AssetCode, + tx.AssetIssuer, tx.Destination, + TransactionStatusPending, tx.Amount, + ) + + assert.Len(t, createdTxs, txCount) + var createdTxIDs []string + for _, createdTx := range createdTxs { + createdTxIDs = append(createdTxIDs, createdTx.ID) + } + + DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + txModel := TransactionModel{DBConnectionPool: dbConnectionPool} + + for _, id := range createdTxIDs { + tx, err := txModel.Get(ctx, id) + require.EqualError(t, err, ErrRecordNotFound.Error()) + assert.Nil(t, tx) + } + }) +} + +func Test_Fixtures_CreateChannelAccountsOnChainFixtures(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + ctx := context.Background() + channelAccountsCount := 5 + channelAccounts := CreateChannelAccountFixtures(t, ctx, dbConnectionPool, channelAccountsCount) + assert.Len(t, channelAccounts, channelAccountsCount) +} diff --git a/internal/transactionsubmission/store/mocks/channel_account_store.go b/internal/transactionsubmission/store/mocks/channel_account_store.go new file mode 100644 index 000000000..2cc7be192 --- /dev/null +++ b/internal/transactionsubmission/store/mocks/channel_account_store.go @@ -0,0 +1,296 @@ +// Code generated by mockery v2.27.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + db "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + mock "github.com/stretchr/testify/mock" + + store "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" +) + +// MockChannelAccountStore is an autogenerated mock type for the ChannelAccountStore type +type MockChannelAccountStore struct { + mock.Mock +} + +// BatchInsert provides a mock function with given fields: ctx, sqlExec, channelAccounts +func (_m *MockChannelAccountStore) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*store.ChannelAccount) error { + ret := _m.Called(ctx, sqlExec, channelAccounts) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, []*store.ChannelAccount) error); ok { + r0 = rf(ctx, sqlExec, channelAccounts) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// BatchInsertAndLock provides a mock function with given fields: ctx, channelAccounts, currentLedger, nextLedgerLock +func (_m *MockChannelAccountStore) BatchInsertAndLock(ctx context.Context, channelAccounts []*store.ChannelAccount, currentLedger int, nextLedgerLock int) error { + ret := _m.Called(ctx, channelAccounts, currentLedger, nextLedgerLock) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []*store.ChannelAccount, int, int) error); ok { + r0 = rf(ctx, channelAccounts, currentLedger, nextLedgerLock) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Count provides a mock function with given fields: ctx +func (_m *MockChannelAccountStore) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (int, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Delete provides a mock function with given fields: ctx, sqlExec, publicKey +func (_m *MockChannelAccountStore) Delete(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) error { + ret := _m.Called(ctx, sqlExec, publicKey) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string) error); ok { + r0 = rf(ctx, sqlExec, publicKey) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteIfLockedUntil provides a mock function with given fields: ctx, publicKey, lockedUntilLedgerNumber +func (_m *MockChannelAccountStore) DeleteIfLockedUntil(ctx context.Context, publicKey string, lockedUntilLedgerNumber int) error { + ret := _m.Called(ctx, publicKey, lockedUntilLedgerNumber) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, int) error); ok { + r0 = rf(ctx, publicKey, lockedUntilLedgerNumber) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Get provides a mock function with given fields: ctx, sqlExec, publicKey, currentLedgerNumber +func (_m *MockChannelAccountStore) Get(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, currentLedgerNumber int) (*store.ChannelAccount, error) { + ret := _m.Called(ctx, sqlExec, publicKey, currentLedgerNumber) + + var r0 *store.ChannelAccount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string, int) (*store.ChannelAccount, error)); ok { + return rf(ctx, sqlExec, publicKey, currentLedgerNumber) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string, int) *store.ChannelAccount); ok { + r0 = rf(ctx, sqlExec, publicKey, currentLedgerNumber) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.ChannelAccount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, string, int) error); ok { + r1 = rf(ctx, sqlExec, publicKey, currentLedgerNumber) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAll provides a mock function with given fields: ctx, sqlExec, currentLedger, limit +func (_m *MockChannelAccountStore) GetAll(ctx context.Context, sqlExec db.SQLExecuter, currentLedger int, limit int) ([]*store.ChannelAccount, error) { + ret := _m.Called(ctx, sqlExec, currentLedger, limit) + + var r0 []*store.ChannelAccount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, int, int) ([]*store.ChannelAccount, error)); ok { + return rf(ctx, sqlExec, currentLedger, limit) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, int, int) []*store.ChannelAccount); ok { + r0 = rf(ctx, sqlExec, currentLedger, limit) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*store.ChannelAccount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, int, int) error); ok { + r1 = rf(ctx, sqlExec, currentLedger, limit) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAndLock provides a mock function with given fields: ctx, publicKey, currentLedger, nextLedgerLock +func (_m *MockChannelAccountStore) GetAndLock(ctx context.Context, publicKey string, currentLedger int, nextLedgerLock int) (*store.ChannelAccount, error) { + ret := _m.Called(ctx, publicKey, currentLedger, nextLedgerLock) + + var r0 *store.ChannelAccount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, int, int) (*store.ChannelAccount, error)); ok { + return rf(ctx, publicKey, currentLedger, nextLedgerLock) + } + if rf, ok := ret.Get(0).(func(context.Context, string, int, int) *store.ChannelAccount); ok { + r0 = rf(ctx, publicKey, currentLedger, nextLedgerLock) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.ChannelAccount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, int, int) error); ok { + r1 = rf(ctx, publicKey, currentLedger, nextLedgerLock) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAndLockAll provides a mock function with given fields: ctx, currentLedger, nextLedgerLock, limit +func (_m *MockChannelAccountStore) GetAndLockAll(ctx context.Context, currentLedger int, nextLedgerLock int, limit int) ([]*store.ChannelAccount, error) { + ret := _m.Called(ctx, currentLedger, nextLedgerLock, limit) + + var r0 []*store.ChannelAccount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int, int, int) ([]*store.ChannelAccount, error)); ok { + return rf(ctx, currentLedger, nextLedgerLock, limit) + } + if rf, ok := ret.Get(0).(func(context.Context, int, int, int) []*store.ChannelAccount); ok { + r0 = rf(ctx, currentLedger, nextLedgerLock, limit) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*store.ChannelAccount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int, int, int) error); ok { + r1 = rf(ctx, currentLedger, nextLedgerLock, limit) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Insert provides a mock function with given fields: ctx, sqlExec, publicKey, privateKey +func (_m *MockChannelAccountStore) Insert(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, privateKey string) error { + ret := _m.Called(ctx, sqlExec, publicKey, privateKey) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string, string) error); ok { + r0 = rf(ctx, sqlExec, publicKey, privateKey) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// InsertAndLock provides a mock function with given fields: ctx, publicKey, privateKey, currentLedger, nextLedgerLock +func (_m *MockChannelAccountStore) InsertAndLock(ctx context.Context, publicKey string, privateKey string, currentLedger int, nextLedgerLock int) error { + ret := _m.Called(ctx, publicKey, privateKey, currentLedger, nextLedgerLock) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, int, int) error); ok { + r0 = rf(ctx, publicKey, privateKey, currentLedger, nextLedgerLock) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Lock provides a mock function with given fields: ctx, sqlExec, publicKey, currentLedger, nextLedgerLock +func (_m *MockChannelAccountStore) Lock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, currentLedger int32, nextLedgerLock int32) (*store.ChannelAccount, error) { + ret := _m.Called(ctx, sqlExec, publicKey, currentLedger, nextLedgerLock) + + var r0 *store.ChannelAccount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string, int32, int32) (*store.ChannelAccount, error)); ok { + return rf(ctx, sqlExec, publicKey, currentLedger, nextLedgerLock) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string, int32, int32) *store.ChannelAccount); ok { + r0 = rf(ctx, sqlExec, publicKey, currentLedger, nextLedgerLock) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.ChannelAccount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, string, int32, int32) error); ok { + r1 = rf(ctx, sqlExec, publicKey, currentLedger, nextLedgerLock) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Unlock provides a mock function with given fields: ctx, sqlExec, publicKey +func (_m *MockChannelAccountStore) Unlock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*store.ChannelAccount, error) { + ret := _m.Called(ctx, sqlExec, publicKey) + + var r0 *store.ChannelAccount + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string) (*store.ChannelAccount, error)); ok { + return rf(ctx, sqlExec, publicKey) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string) *store.ChannelAccount); ok { + r0 = rf(ctx, sqlExec, publicKey) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.ChannelAccount) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, string) error); ok { + r1 = rf(ctx, sqlExec, publicKey) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewMockChannelAccountStore interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockChannelAccountStore creates a new instance of MockChannelAccountStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockChannelAccountStore(t mockConstructorTestingTNewMockChannelAccountStore) *MockChannelAccountStore { + mock := &MockChannelAccountStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/transactionsubmission/store/mocks/transaction_store.go b/internal/transactionsubmission/store/mocks/transaction_store.go new file mode 100644 index 000000000..e2c7d4575 --- /dev/null +++ b/internal/transactionsubmission/store/mocks/transaction_store.go @@ -0,0 +1,358 @@ +// Code generated by mockery v2.27.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + db "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + mock "github.com/stretchr/testify/mock" + + store "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" +) + +// MockTransactionStore is an autogenerated mock type for the TransactionStore type +type MockTransactionStore struct { + mock.Mock +} + +// BulkInsert provides a mock function with given fields: ctx, sqlExec, transactions +func (_m *MockTransactionStore) BulkInsert(ctx context.Context, sqlExec db.SQLExecuter, transactions []store.Transaction) ([]store.Transaction, error) { + ret := _m.Called(ctx, sqlExec, transactions) + + var r0 []store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, []store.Transaction) ([]store.Transaction, error)); ok { + return rf(ctx, sqlExec, transactions) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, []store.Transaction) []store.Transaction); ok { + r0 = rf(ctx, sqlExec, transactions) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, []store.Transaction) error); ok { + r1 = rf(ctx, sqlExec, transactions) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Get provides a mock function with given fields: ctx, txID +func (_m *MockTransactionStore) Get(ctx context.Context, txID string) (*store.Transaction, error) { + ret := _m.Called(ctx, txID) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*store.Transaction, error)); ok { + return rf(ctx, txID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *store.Transaction); ok { + r0 = rf(ctx, txID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, txID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAllByPaymentIDs provides a mock function with given fields: ctx, paymentIDs +func (_m *MockTransactionStore) GetAllByPaymentIDs(ctx context.Context, paymentIDs []string) ([]*store.Transaction, error) { + ret := _m.Called(ctx, paymentIDs) + + var r0 []*store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []string) ([]*store.Transaction, error)); ok { + return rf(ctx, paymentIDs) + } + if rf, ok := ret.Get(0).(func(context.Context, []string) []*store.Transaction); ok { + r0 = rf(ctx, paymentIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, paymentIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTransactionBatchForUpdate provides a mock function with given fields: ctx, dbTx, batchSize +func (_m *MockTransactionStore) GetTransactionBatchForUpdate(ctx context.Context, dbTx db.DBTransaction, batchSize int) ([]*store.Transaction, error) { + ret := _m.Called(ctx, dbTx, batchSize) + + var r0 []*store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.DBTransaction, int) ([]*store.Transaction, error)); ok { + return rf(ctx, dbTx, batchSize) + } + if rf, ok := ret.Get(0).(func(context.Context, db.DBTransaction, int) []*store.Transaction); ok { + r0 = rf(ctx, dbTx, batchSize) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.DBTransaction, int) error); ok { + r1 = rf(ctx, dbTx, batchSize) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Insert provides a mock function with given fields: ctx, tx +func (_m *MockTransactionStore) Insert(ctx context.Context, tx store.Transaction) (*store.Transaction, error) { + ret := _m.Called(ctx, tx) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, store.Transaction) (*store.Transaction, error)); ok { + return rf(ctx, tx) + } + if rf, ok := ret.Get(0).(func(context.Context, store.Transaction) *store.Transaction); ok { + r0 = rf(ctx, tx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, store.Transaction) error); ok { + r1 = rf(ctx, tx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Lock provides a mock function with given fields: ctx, sqlExec, transactionID, currentLedger, nextLedgerLock +func (_m *MockTransactionStore) Lock(ctx context.Context, sqlExec db.SQLExecuter, transactionID string, currentLedger int32, nextLedgerLock int32) (*store.Transaction, error) { + ret := _m.Called(ctx, sqlExec, transactionID, currentLedger, nextLedgerLock) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string, int32, int32) (*store.Transaction, error)); ok { + return rf(ctx, sqlExec, transactionID, currentLedger, nextLedgerLock) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string, int32, int32) *store.Transaction); ok { + r0 = rf(ctx, sqlExec, transactionID, currentLedger, nextLedgerLock) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, string, int32, int32) error); ok { + r1 = rf(ctx, sqlExec, transactionID, currentLedger, nextLedgerLock) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PrepareTransactionForReprocessing provides a mock function with given fields: ctx, sqlExec, transactionID +func (_m *MockTransactionStore) PrepareTransactionForReprocessing(ctx context.Context, sqlExec db.SQLExecuter, transactionID string) (*store.Transaction, error) { + ret := _m.Called(ctx, sqlExec, transactionID) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string) (*store.Transaction, error)); ok { + return rf(ctx, sqlExec, transactionID) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string) *store.Transaction); ok { + r0 = rf(ctx, sqlExec, transactionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, string) error); ok { + r1 = rf(ctx, sqlExec, transactionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Unlock provides a mock function with given fields: ctx, sqlExec, publicKey +func (_m *MockTransactionStore) Unlock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*store.Transaction, error) { + ret := _m.Called(ctx, sqlExec, publicKey) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string) (*store.Transaction, error)); ok { + return rf(ctx, sqlExec, publicKey) + } + if rf, ok := ret.Get(0).(func(context.Context, db.SQLExecuter, string) *store.Transaction); ok { + r0 = rf(ctx, sqlExec, publicKey) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, db.SQLExecuter, string) error); ok { + r1 = rf(ctx, sqlExec, publicKey) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateStatusToError provides a mock function with given fields: ctx, tx, message +func (_m *MockTransactionStore) UpdateStatusToError(ctx context.Context, tx store.Transaction, message string) (*store.Transaction, error) { + ret := _m.Called(ctx, tx, message) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, store.Transaction, string) (*store.Transaction, error)); ok { + return rf(ctx, tx, message) + } + if rf, ok := ret.Get(0).(func(context.Context, store.Transaction, string) *store.Transaction); ok { + r0 = rf(ctx, tx, message) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, store.Transaction, string) error); ok { + r1 = rf(ctx, tx, message) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateStatusToSuccess provides a mock function with given fields: ctx, tx +func (_m *MockTransactionStore) UpdateStatusToSuccess(ctx context.Context, tx store.Transaction) (*store.Transaction, error) { + ret := _m.Called(ctx, tx) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, store.Transaction) (*store.Transaction, error)); ok { + return rf(ctx, tx) + } + if rf, ok := ret.Get(0).(func(context.Context, store.Transaction) *store.Transaction); ok { + r0 = rf(ctx, tx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, store.Transaction) error); ok { + r1 = rf(ctx, tx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateStellarTransactionHashAndXDRSent provides a mock function with given fields: ctx, txID, txHash, txXDRSent +func (_m *MockTransactionStore) UpdateStellarTransactionHashAndXDRSent(ctx context.Context, txID string, txHash string, txXDRSent string) (*store.Transaction, error) { + ret := _m.Called(ctx, txID, txHash, txXDRSent) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*store.Transaction, error)); ok { + return rf(ctx, txID, txHash, txXDRSent) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *store.Transaction); ok { + r0 = rf(ctx, txID, txHash, txXDRSent) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, txID, txHash, txXDRSent) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateStellarTransactionXDRReceived provides a mock function with given fields: ctx, txID, xdrReceived +func (_m *MockTransactionStore) UpdateStellarTransactionXDRReceived(ctx context.Context, txID string, xdrReceived string) (*store.Transaction, error) { + ret := _m.Called(ctx, txID, xdrReceived) + + var r0 *store.Transaction + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*store.Transaction, error)); ok { + return rf(ctx, txID, xdrReceived) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) *store.Transaction); ok { + r0 = rf(ctx, txID, xdrReceived) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.Transaction) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, txID, xdrReceived) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateSyncedTransactions provides a mock function with given fields: ctx, dbTx, txIDs +func (_m *MockTransactionStore) UpdateSyncedTransactions(ctx context.Context, dbTx db.DBTransaction, txIDs []string) error { + ret := _m.Called(ctx, dbTx, txIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, db.DBTransaction, []string) error); ok { + r0 = rf(ctx, dbTx, txIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewMockTransactionStore interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockTransactionStore creates a new instance of MockTransactionStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockTransactionStore(t mockConstructorTestingTNewMockTransactionStore) *MockTransactionStore { + mock := &MockTransactionStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/transactionsubmission/store/store.go b/internal/transactionsubmission/store/store.go new file mode 100644 index 000000000..b86bee007 --- /dev/null +++ b/internal/transactionsubmission/store/store.go @@ -0,0 +1,45 @@ +package store + +import ( + "context" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +//go:generate mockery --name=ChannelAccountStore --case=underscore --structname=MockChannelAccountStore +type ChannelAccountStore interface { + Delete(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (err error) + DeleteIfLockedUntil(ctx context.Context, publicKey string, lockedUntilLedgerNumber int) (err error) + Get(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, currentLedgerNumber int) (ca *ChannelAccount, err error) + GetAndLock(ctx context.Context, publicKey string, currentLedger, nextLedgerLock int) (*ChannelAccount, error) + Count(ctx context.Context) (count int, err error) + GetAll(ctx context.Context, sqlExec db.SQLExecuter, currentLedger, limit int) ([]*ChannelAccount, error) + GetAndLockAll(ctx context.Context, currentLedger, nextLedgerLock, limit int) ([]*ChannelAccount, error) + Insert(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, privateKey string) error + InsertAndLock(ctx context.Context, publicKey string, privateKey string, currentLedger, nextLedgerLock int) error + BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error + BatchInsertAndLock(ctx context.Context, channelAccounts []*ChannelAccount, currentLedger, nextLedgerLock int) error + // Lock management: + Lock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string, currentLedger, nextLedgerLock int32) (*ChannelAccount, error) + Unlock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*ChannelAccount, error) +} + +//go:generate mockery --name=TransactionStore --case=underscore --structname=MockTransactionStore +type TransactionStore interface { + // CRUD: + Insert(ctx context.Context, tx Transaction) (*Transaction, error) + BulkInsert(ctx context.Context, sqlExec db.SQLExecuter, transactions []Transaction) ([]Transaction, error) + Get(ctx context.Context, txID string) (tx *Transaction, err error) + GetAllByPaymentIDs(ctx context.Context, paymentIDs []string) (transactions []*Transaction, err error) + // Status & Lock management: + UpdateStatusToSuccess(ctx context.Context, tx Transaction) (updatedTx *Transaction, err error) + UpdateStatusToError(ctx context.Context, tx Transaction, message string) (updatedTx *Transaction, err error) + UpdateStellarTransactionXDRReceived(ctx context.Context, txID string, xdrReceived string) (*Transaction, error) + UpdateStellarTransactionHashAndXDRSent(ctx context.Context, txID string, txHash, txXDRSent string) (*Transaction, error) + Lock(ctx context.Context, sqlExec db.SQLExecuter, transactionID string, currentLedger, nextLedgerLock int32) (*Transaction, error) + Unlock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*Transaction, error) + // Queue management: + PrepareTransactionForReprocessing(ctx context.Context, sqlExec db.SQLExecuter, transactionID string) (*Transaction, error) + GetTransactionBatchForUpdate(ctx context.Context, dbTx db.DBTransaction, batchSize int) (transactions []*Transaction, err error) + UpdateSyncedTransactions(ctx context.Context, dbTx db.DBTransaction, txIDs []string) error +} diff --git a/internal/transactionsubmission/store/transaction_state_machine.go b/internal/transactionsubmission/store/transaction_state_machine.go new file mode 100644 index 000000000..dd036c3e3 --- /dev/null +++ b/internal/transactionsubmission/store/transaction_state_machine.go @@ -0,0 +1,54 @@ +package store + +import ( + "fmt" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "golang.org/x/exp/slices" +) + +type TransactionStatus string + +const ( + // TransactionStatusPending indicates that a transaction has been created and added to the queue. + TransactionStatusPending TransactionStatus = "PENDING" // TODO: rename to TransactionStatusQueued + // TransactionStatusProcessing indicates that a transaction has been read from the queue and is being processed. + TransactionStatusProcessing TransactionStatus = "PROCESSING" + // TransactionStatusSuccess indicates that the transaction was successfully sent and included in the ledger. + TransactionStatusSuccess TransactionStatus = "SUCCESS" + // TransactionStatusError indicates that there was an error when trying to send this transaction. + TransactionStatusError TransactionStatus = "ERROR" +) + +func (status TransactionStatus) All() []TransactionStatus { + return []TransactionStatus{TransactionStatusPending, TransactionStatusProcessing, TransactionStatusSuccess, TransactionStatusError} +} + +// Validate validates the disbursement status +func (status TransactionStatus) Validate() error { + if slices.Contains(TransactionStatus("").All(), status) { + return nil + } + return fmt.Errorf("invalid disbursement status: %s", status) +} + +// State will parse the TransactionState into a data.State. +func (status TransactionStatus) State() data.State { + return data.State(status) +} + +// CanTransitionTo verifies if the transition is allowed. +func (status TransactionStatus) CanTransitionTo(targetState TransactionStatus) error { + return tssTransactionStateMachineWithInitialState(status).TransitionTo(targetState.State()) +} + +// tssTransactionStateMachineWithInitialState returns a state machine for TSS transactions, initialized with the given state. +func tssTransactionStateMachineWithInitialState(initialState TransactionStatus) *data.StateMachine { + transitions := []data.StateTransition{ + {From: TransactionStatusPending.State(), To: TransactionStatusProcessing.State()}, // TSS loads the transaction from the DB for the first time. + {From: TransactionStatusProcessing.State(), To: TransactionStatusSuccess.State()}, // TSS receives a success response from Stellar Horizon. + {From: TransactionStatusProcessing.State(), To: TransactionStatusError.State()}, // TSS receives an error response from Stellar Horizon. + } + + return data.NewStateMachine(initialState.State(), transitions) +} diff --git a/internal/transactionsubmission/store/transaction_state_machine_test.go b/internal/transactionsubmission/store/transaction_state_machine_test.go new file mode 100644 index 000000000..f10c5c827 --- /dev/null +++ b/internal/transactionsubmission/store/transaction_state_machine_test.go @@ -0,0 +1,129 @@ +package store + +import ( + "fmt" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/data" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_TransactionStatus_All(t *testing.T) { + allStatuses := TransactionStatus("").All() + require.Len(t, allStatuses, 4) + require.Contains(t, allStatuses, TransactionStatusPending) + require.Contains(t, allStatuses, TransactionStatusProcessing) + require.Contains(t, allStatuses, TransactionStatusSuccess) + require.Contains(t, allStatuses, TransactionStatusError) +} + +func Test_TransactionStatus_Validate(t *testing.T) { + testCases := []struct { + name string + status TransactionStatus + wantError error + }{ + { + name: "valid status (PENDING)", + status: TransactionStatusPending, + }, + { + name: "valid status (PROCESSING)", + status: TransactionStatusProcessing, + }, + { + name: "valid status (SUCCESS)", + status: TransactionStatusSuccess, + }, + { + name: "valid status (ERROR)", + status: TransactionStatusError, + }, + { + name: "invalid status (UNKNOWN)", + status: TransactionStatus("UNKNOWN"), + wantError: fmt.Errorf("invalid disbursement status: UNKNOWN"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.status.Validate() + if tc.wantError == nil { + require.NoError(t, err) + } else { + require.Equal(t, tc.wantError, err) + } + }) + } +} + +func Test_TransactionStatus_State(t *testing.T) { + for _, status := range TransactionStatus("").All() { + t.Run(string(status), func(t *testing.T) { + require.Equal(t, data.State(status), status.State()) + }) + } +} + +func Test_TransactionStatus_CanTransitionTo(t *testing.T) { + type canTransitionTestCase struct { + name string + from TransactionStatus + to TransactionStatus + canTransition bool + } + newCanTransitionTestCase := func(from TransactionStatus, to TransactionStatus, canTransition bool) canTransitionTestCase { + namePrefix := "πŸ›£οΈ" + if !canTransition { + namePrefix = "🚧" + } + return canTransitionTestCase{ + name: fmt.Sprintf("[%s]%s->%s", namePrefix, from, to), + from: from, + to: to, + canTransition: canTransition, + } + } + + testCases := []struct { + name string + from TransactionStatus + to TransactionStatus + canTransition bool + }{ + // TransactionStatusPending -> ANY + newCanTransitionTestCase(TransactionStatusPending, TransactionStatusPending, false), + newCanTransitionTestCase(TransactionStatusPending, TransactionStatusProcessing, true), + newCanTransitionTestCase(TransactionStatusPending, TransactionStatusSuccess, false), + newCanTransitionTestCase(TransactionStatusPending, TransactionStatusError, false), + // TransactionStatusProcessing -> ANY + newCanTransitionTestCase(TransactionStatusProcessing, TransactionStatusPending, false), + newCanTransitionTestCase(TransactionStatusProcessing, TransactionStatusProcessing, false), + newCanTransitionTestCase(TransactionStatusProcessing, TransactionStatusSuccess, true), + newCanTransitionTestCase(TransactionStatusProcessing, TransactionStatusError, true), + // TransactionStatusSuccess -> ANY + newCanTransitionTestCase(TransactionStatusSuccess, TransactionStatusPending, false), + newCanTransitionTestCase(TransactionStatusSuccess, TransactionStatusProcessing, false), + newCanTransitionTestCase(TransactionStatusSuccess, TransactionStatusSuccess, false), + newCanTransitionTestCase(TransactionStatusSuccess, TransactionStatusError, false), + // TransactionStatusError -> ANY + newCanTransitionTestCase(TransactionStatusError, TransactionStatusPending, false), + newCanTransitionTestCase(TransactionStatusError, TransactionStatusProcessing, false), + newCanTransitionTestCase(TransactionStatusError, TransactionStatusSuccess, false), + newCanTransitionTestCase(TransactionStatusError, TransactionStatusError, false), + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.from.CanTransitionTo(tc.to) + if tc.canTransition { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.EqualError(t, err, fmt.Sprintf("cannot transition from %s to %s", tc.from, tc.to)) + } + }) + } +} diff --git a/internal/transactionsubmission/store/transactions.go b/internal/transactionsubmission/store/transactions.go new file mode 100644 index 000000000..c56937c5c --- /dev/null +++ b/internal/transactionsubmission/store/transactions.go @@ -0,0 +1,516 @@ +package store + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/lib/pq" + "github.com/stellar/go/strkey" + "github.com/stellar/go/xdr" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" +) + +var ErrRecordNotFound = errors.New("record not found") + +type Transaction struct { + ID string `db:"id"` + // ExternalID contains an external ID for the transaction. This is used for reconciliation. + ExternalID string `db:"external_id"` + // Status is the status of the transaction. Don't change it directly and use the internal methods of the model instead. + Status TransactionStatus `db:"status"` + StatusMessage sql.NullString `db:"status_message"` + StatusHistory TransactionStatusHistory `db:"status_history"` + AssetCode string `db:"asset_code"` + AssetIssuer string `db:"asset_issuer"` + Amount float64 `db:"amount"` + Destination string `db:"destination"` + + CreatedAt *time.Time `db:"created_at"` + UpdatedAt *time.Time `db:"updated_at"` + // StartedAt is when the transaction was read from the queue into memory. + StartedAt *time.Time `db:"started_at"` + // SentAt is when the transaction was sent to the Stellar network. + SentAt *time.Time `db:"sent_at"` + // CompletedAt is when the transaction reached a terminal state, either SUCCESS or ERROR. + CompletedAt *time.Time `db:"completed_at"` + // SyncedAt is when the transaction was synced with SDP. + SyncedAt *time.Time `db:"synced_at"` + + AttemptsCount int `db:"attempts_count"` + StellarTransactionHash sql.NullString `db:"stellar_transaction_hash"` + // XDRSent is the EnvelopeXDR submitted when creating a Stellar transaction in the network. + XDRSent sql.NullString `db:"xdr_sent"` + // XDRReceived is the ResultXDR received from the Stellar network when attempting to create a transaction. + XDRReceived sql.NullString `db:"xdr_received"` + LockedAt *time.Time `db:"locked_at"` + // LockedUntilLedgerNumber is the ledger number after which the lock expires. It should be synched with the + // expiration ledger bound set in the Stellar transaction submitted to the blockchain, and the same value in the + // namesake column of the channel account model. + LockedUntilLedgerNumber sql.NullInt32 `db:"locked_until_ledger_number"` +} + +func (tx *Transaction) IsLocked(currentLedgerNumber int32) bool { + return tx.LockedUntilLedgerNumber.Valid && currentLedgerNumber <= tx.LockedUntilLedgerNumber.Int32 +} + +// validate checks if the transaction fields are valid and can be added to the DB. +func (tx *Transaction) validate() error { + if tx.ExternalID == "" { + return fmt.Errorf("external ID is required") + } + if len(tx.AssetCode) < 1 || len(tx.AssetCode) > 12 { + return fmt.Errorf("asset code must have between 1 and 12 characters") + } + if strings.ToLower(tx.AssetCode) != "xlm" { + if tx.AssetIssuer == "" { + return fmt.Errorf("asset issuer is required") + } + + if !strkey.IsValidEd25519PublicKey(tx.AssetIssuer) { + return fmt.Errorf("asset issuer %q is not a valid ed25519 public key", tx.AssetIssuer) + } + } + if tx.Amount <= 0 { + return fmt.Errorf("amount must be positive") + } + if !strkey.IsValidEd25519PublicKey(tx.Destination) { + return fmt.Errorf("destination %q is not a valid ed25519 public key", tx.Destination) + } + return nil +} + +type TransactionModel struct { + DBConnectionPool db.DBConnectionPool +} + +func NewTransactionModel(dbConnectionPool db.DBConnectionPool) *TransactionModel { + return &TransactionModel{DBConnectionPool: dbConnectionPool} +} + +// Insert adds a new Transaction to the database. +func (t *TransactionModel) Insert(ctx context.Context, tx Transaction) (*Transaction, error) { + transactions, err := t.BulkInsert(ctx, t.DBConnectionPool, []Transaction{tx}) + if err != nil { + return nil, fmt.Errorf("inserting single transaction: %w", err) + } + + return &transactions[0], nil +} + +// BulkInsert adds a batch of Transactions to the database and returns the inserted transactions. +func (t *TransactionModel) BulkInsert(ctx context.Context, sqlExec db.SQLExecuter, transactions []Transaction) ([]Transaction, error) { + if len(transactions) == 0 { + return nil, nil + } + + var queryBuilder strings.Builder + queryBuilder.WriteString("INSERT INTO submitter_transactions (external_id, asset_code, asset_issuer, amount, destination) VALUES ") + valueStrings := make([]string, 0, len(transactions)) + valueArgs := make([]interface{}, 0, len(transactions)*6) + + for _, transaction := range transactions { + if err := transaction.validate(); err != nil { + return nil, fmt.Errorf("validating transaction for insertion: %w", err) + } + valueStrings = append(valueStrings, "(?, ?, ?, ?, ?)") + valueArgs = append(valueArgs, + transaction.ExternalID, + transaction.AssetCode, + transaction.AssetIssuer, + transaction.Amount, + transaction.Destination, + ) + } + + var insertedTransctions []Transaction + queryBuilder.WriteString(strings.Join(valueStrings, ", ")) + queryBuilder.WriteString(" RETURNING *") + query := sqlExec.Rebind(queryBuilder.String()) + err := sqlExec.SelectContext(ctx, &insertedTransctions, query, valueArgs...) + if err != nil { + return nil, fmt.Errorf("inserting transactions: %w", err) + } + + return insertedTransctions, nil +} + +// Get gets a Transaction from the database. +func (t *TransactionModel) Get(ctx context.Context, txID string) (*Transaction, error) { + var transaction Transaction + q := ` + SELECT + * + FROM + submitter_transactions t + WHERE + t.id = $1 + ` + err := t.DBConnectionPool.GetContext(ctx, &transaction, q, txID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error querying transaction ID %s: %w", txID, err) + } + return &transaction, err +} + +func (t *TransactionModel) GetAllByPaymentIDs(ctx context.Context, paymentIDs []string) ([]*Transaction, error) { + var transactions []*Transaction + q := ` + SELECT + * + FROM + submitter_transactions t + WHERE + t.external_id = ANY($1) + ` + err := t.DBConnectionPool.SelectContext(ctx, &transactions, q, pq.Array(paymentIDs)) + if err != nil { + return nil, fmt.Errorf("error querying transactions: %w", err) + } + return transactions, nil +} + +// UpdateStatusToSuccess updates a Transaction's status to SUCCESS. Only succeeds if the current status is PROCESSING. +func (t *TransactionModel) UpdateStatusToSuccess(ctx context.Context, tx Transaction) (*Transaction, error) { + // verify if this state transition is valid: + err := tx.Status.CanTransitionTo(TransactionStatusSuccess) + if err != nil { + return nil, fmt.Errorf("attempting to transition transaction status to TransactionStatusSuccess: %w", err) + } + + var updatedTx Transaction + query := ` + UPDATE + submitter_transactions + SET + status = $1, + completed_at = NOW(), + status_history = array_append(status_history, create_submitter_transactions_status_history(NOW(), $1::transaction_status, NULL, stellar_transaction_hash, xdr_sent, xdr_received)) + WHERE + id = $2 + RETURNING + * + ` + err = t.DBConnectionPool.GetContext(ctx, &updatedTx, query, TransactionStatusSuccess, tx.ID) + if err != nil { + return nil, fmt.Errorf("updating transaction status to TransactionStatusSuccess: %w", err) + } + + return &updatedTx, nil +} + +// UpdateStatusToError updates a Transaction's status to ERROR. Only succeeds if the current status is PROCESSING. +func (t *TransactionModel) UpdateStatusToError(ctx context.Context, tx Transaction, message string) (*Transaction, error) { + // verify if this state transition is valid: + err := tx.Status.CanTransitionTo(TransactionStatusError) + if err != nil { + return nil, fmt.Errorf("attempting to transition transaction status to TransactionStatusError: %w", err) + } + + var updatedTx Transaction + query := ` + UPDATE + submitter_transactions + SET + status = $1, + completed_at = NOW(), + status_message = $2, + status_history = array_append(status_history, create_submitter_transactions_status_history(NOW(), $1::transaction_status, $2::text, stellar_transaction_hash, xdr_sent, xdr_received)) + WHERE + id = $3 + RETURNING + * + ` + err = t.DBConnectionPool.GetContext(ctx, &updatedTx, query, TransactionStatusError, message, tx.ID) + if err != nil { + return nil, fmt.Errorf("updating transaction status to TransactionStatusError: %w", err) + } + + return &updatedTx, nil +} + +func (t *TransactionModel) UpdateStellarTransactionHashAndXDRSent(ctx context.Context, txID string, txHash, txXDRSent string) (*Transaction, error) { + if len(txHash) != 64 { + return nil, fmt.Errorf("invalid transaction hash %q", txHash) + } + + var txEnvelope xdr.TransactionEnvelope + err := xdr.SafeUnmarshalBase64(txXDRSent, &txEnvelope) + if err != nil { + return nil, fmt.Errorf("invalid XDR envelope: %w", err) + } + + query := ` + UPDATE + submitter_transactions + SET + stellar_transaction_hash = $1::text, + xdr_sent = $2, + sent_at = NOW(), + status_history = array_append(status_history, create_submitter_transactions_status_history(NOW(), status, 'Updating Stellar Transaction Hash', $1::text, $2, xdr_received)), + attempts_count = attempts_count + 1 + WHERE + id = $3 + RETURNING + * + ` + var tx Transaction + err = t.DBConnectionPool.GetContext(ctx, &tx, query, txHash, txXDRSent, txID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error updating transaction hash: %w", err) + } + + return &tx, nil +} + +// UpdateStellarTransactionXDRReceived updates a Transaction's XDR received. +func (t *TransactionModel) UpdateStellarTransactionXDRReceived(ctx context.Context, txID string, xdrReceived string) (*Transaction, error) { + var txResult xdr.TransactionResult + err := xdr.SafeUnmarshalBase64(xdrReceived, &txResult) + if err != nil { + return nil, fmt.Errorf("invalid XDR result: %w", err) + } + + query := ` + UPDATE + submitter_transactions + SET + xdr_received = $1, + status_history = array_append(status_history, create_submitter_transactions_status_history(NOW(), status, 'Updating XDR Received', stellar_transaction_hash, xdr_sent, $1::text)) + WHERE + id = $2 + RETURNING + * + ` + var updatedTx Transaction + err = t.DBConnectionPool.GetContext(ctx, &updatedTx, query, xdrReceived, txID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("error updating transaction hash: %w", err) + } + + return &updatedTx, nil +} + +// GetTransactionBatchForUpdate returns a batch of transactions that are ready to be synced. Locks the rows for update. +func (t *TransactionModel) GetTransactionBatchForUpdate(ctx context.Context, dbTx db.DBTransaction, batchSize int) ([]*Transaction, error) { + if batchSize <= 0 { + return nil, fmt.Errorf("batch size must be greater than 0") + } + + transactions := []*Transaction{} + + query := ` + SELECT + * + FROM + submitter_transactions + WHERE + status IN ('SUCCESS', 'ERROR') + AND synced_at IS NULL + ORDER BY + completed_at ASC + LIMIT + $1 + FOR UPDATE SKIP LOCKED + ` + + err := dbTx.SelectContext(ctx, &transactions, query, batchSize) + if err != nil { + return nil, fmt.Errorf("getting transactions: %w", err) + } + + return transactions, nil +} + +// UpdateSyncedTransactions updates the synced_at field for the given transaction IDs. Returns an error if the number of +// updated rows is not equal to the number of provided transaction IDs. +func (t *TransactionModel) UpdateSyncedTransactions(ctx context.Context, dbTx db.DBTransaction, txIDs []string) error { + if len(txIDs) == 0 { + return fmt.Errorf("no transaction IDs provided") + } + + query := ` + UPDATE + submitter_transactions + SET + synced_at = NOW() + WHERE + id = ANY($1) + AND status = ANY($2) + ` + + allowedStatuses := []TransactionStatus{TransactionStatusSuccess, TransactionStatusError} + result, err := dbTx.ExecContext(ctx, query, pq.Array(txIDs), pq.Array(allowedStatuses)) + if err != nil { + return fmt.Errorf("updating transactions: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("getting rows affected: %w", err) + } + + if rowsAffected != int64(len(txIDs)) { + return fmt.Errorf("expected %d rows to be affected, got %d", len(txIDs), rowsAffected) + } + + return nil +} + +// queryFilterForLockedState returns a SQL query filter that can be used to filter transactions based on their locked +// state. +func (ca *TransactionModel) queryFilterForLockedState(locked bool, ledgerNumber int32) string { + if locked { + return fmt.Sprintf("(locked_until_ledger_number >= %d)", ledgerNumber) + } + return fmt.Sprintf("(locked_until_ledger_number IS NULL OR locked_until_ledger_number < %d)", ledgerNumber) +} + +// Lock locks the transaction with the provided transactionID. It returns a ErrRecordNotFound error if you try to lock a +// transaction that is already locked. +func (ca *TransactionModel) Lock(ctx context.Context, sqlExec db.SQLExecuter, transactionID string, currentLedger, nextLedgerLock int32) (*Transaction, error) { + q := fmt.Sprintf(` + UPDATE + submitter_transactions + SET + locked_at = NOW(), + locked_until_ledger_number = $1, + status = $2 + WHERE + id = $3 + AND %s + AND synced_at IS NULL + AND status = ANY($4) + RETURNING * + `, ca.queryFilterForLockedState(false, currentLedger)) + var transaction Transaction + allowedTxStatuses := []TransactionStatus{TransactionStatusPending, TransactionStatusProcessing} + err := sqlExec.GetContext(ctx, &transaction, q, nextLedgerLock, TransactionStatusProcessing, transactionID, pq.Array(allowedTxStatuses)) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("locking transaction %q: %w", transactionID, err) + } + + return &transaction, nil +} + +// Unlock lifts the lock from the transactionID with the provided publicKey. +func (ca *TransactionModel) Unlock(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*Transaction, error) { + q := ` + UPDATE + submitter_transactions + SET + locked_at = NULL, + locked_until_ledger_number = NULL + WHERE + id = $1 + RETURNING * + ` + var transaction Transaction + err := sqlExec.GetContext(ctx, &transaction, q, publicKey) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("unlocking transaction %q: %w", publicKey, err) + } + + return &transaction, nil +} + +// PrepareTransactionForReprocessing pushes the transaction with the provided transactionID back to the queue. +func (ca *TransactionModel) PrepareTransactionForReprocessing(ctx context.Context, sqlExec db.SQLExecuter, transactionID string) (*Transaction, error) { + q := ` + UPDATE + submitter_transactions + SET + locked_at = NULL, + locked_until_ledger_number = NULL, + stellar_transaction_hash = NULL, + xdr_sent = NULL, + xdr_received = NULL + WHERE + id = $1 + AND synced_at IS NULL + AND status = ANY($2) + RETURNING * + ` + var transaction Transaction + allowedTxStatuses := []TransactionStatus{TransactionStatusPending, TransactionStatusProcessing} + err := sqlExec.GetContext(ctx, &transaction, q, transactionID, pq.Array(allowedTxStatuses)) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRecordNotFound + } + return nil, fmt.Errorf("pushing transaction back to queue %q: %w", transactionID, err) + } + + return &transaction, nil +} + +var _ TransactionStore = &TransactionModel{} + +type TransactionStatusHistoryEntry struct { + Status string `json:"status"` + StatusMessage string `json:"status_message"` + Timestamp time.Time `json:"timestamp"` + StellarTransactionHash string `json:"stellar_transaction_hash"` + XDRSent string `json:"xdr_sent"` + XDRReceived string `json:"xdr_received"` +} + +type TransactionStatusHistory []TransactionStatusHistoryEntry + +// Value implements the driver.Valuer interface. +func (tsh TransactionStatusHistory) Value() (driver.Value, error) { + var statusHistoryJSON []string + for _, sh := range tsh { + shJSONBytes, err := json.Marshal(sh) + if err != nil { + return nil, fmt.Errorf("error converting status history to json for transaction: %w", err) + } + statusHistoryJSON = append(statusHistoryJSON, string(shJSONBytes)) + } + + return pq.Array(statusHistoryJSON).Value() +} + +// Scan implements the sql.Scanner interface. +func (tsh *TransactionStatusHistory) Scan(src interface{}) error { + var statusHistoryJSON []string + if err := pq.Array(&statusHistoryJSON).Scan(src); err != nil { + return fmt.Errorf("error scanning status history value: %w", err) + } + + for _, sh := range statusHistoryJSON { + var shEntry TransactionStatusHistoryEntry + err := json.Unmarshal([]byte(sh), &shEntry) + if err != nil { + return fmt.Errorf("error unmarshaling status_history column: %w", err) + } + *tsh = append(*tsh, shEntry) + } + + return nil +} + +var ( + _ sql.Scanner = (*TransactionStatusHistory)(nil) + _ driver.Valuer = (*TransactionStatusHistory)(nil) +) diff --git a/internal/transactionsubmission/store/transactions_test.go b/internal/transactionsubmission/store/transactions_test.go new file mode 100644 index 000000000..d0cc52373 --- /dev/null +++ b/internal/transactionsubmission/store/transactions_test.go @@ -0,0 +1,1142 @@ +package store + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stellar/go/keypair" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Transaction_IsLocked(t *testing.T) { + const currentLedgerNumber = 10 + + testCases := []struct { + name string + lockedUntilLedgerNumber sql.NullInt32 + wantResult bool + }{ + { + name: "returns false if lockedUntilLedgerNumber is null", + lockedUntilLedgerNumber: sql.NullInt32{}, + wantResult: false, + }, + { + name: "returns false if lockedUntilLedgerNumber is lower than currentLedgerNumber", + lockedUntilLedgerNumber: sql.NullInt32{Int32: currentLedgerNumber - 1, Valid: true}, + wantResult: false, + }, + { + name: "returns true if lockedUntilLedgerNumber is equal to currentLedgerNumber", + lockedUntilLedgerNumber: sql.NullInt32{Int32: currentLedgerNumber, Valid: true}, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := &Transaction{LockedUntilLedgerNumber: tc.lockedUntilLedgerNumber} + assert.Equal(t, tc.wantResult, tx.IsLocked(currentLedgerNumber)) + }) + } +} + +func Test_TransactionModel_Insert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + + t.Run("return an error if the input parameters are invalid", func(t *testing.T) { + tx, err := txModel.Insert(ctx, Transaction{ExternalID: "external-id-1"}) + require.Error(t, err) + assert.EqualError(t, err, "inserting single transaction: validating transaction for insertion: asset code must have between 1 and 12 characters") + assert.Nil(t, tx) + }) + + t.Run("πŸŽ‰ successfully insert a new Transaction", func(t *testing.T) { + transaction, err := txModel.Insert(ctx, Transaction{ + ExternalID: "external-id-1", + AssetCode: "USDC", + AssetIssuer: "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", + Amount: 1, + Destination: "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + }) + require.NoError(t, err) + require.NotNil(t, transaction) + + refreshedTx, err := txModel.Get(ctx, transaction.ID) + require.NoError(t, err) + assert.Equal(t, transaction, refreshedTx) + + assert.Equal(t, "external-id-1", refreshedTx.ExternalID) + assert.Equal(t, "USDC", refreshedTx.AssetCode) + assert.Equal(t, "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", refreshedTx.AssetIssuer) + assert.Equal(t, float64(1), refreshedTx.Amount) + assert.Equal(t, "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", refreshedTx.Destination) + assert.Equal(t, TransactionStatusPending, refreshedTx.Status) + }) +} + +func Test_TransactionModel_BulkInsert(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + defer DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + t.Run("return nil with no error if the input slice is nil", func(t *testing.T) { + insertedTransactions, err := txModel.BulkInsert(ctx, dbConnectionPool, nil) + require.NoError(t, err) + assert.Nil(t, insertedTransactions) + }) + + t.Run("return nil with no error if the input slice is empty", func(t *testing.T) { + insertedTransactions, err := txModel.BulkInsert(ctx, dbConnectionPool, []Transaction{}) + require.NoError(t, err) + assert.Nil(t, insertedTransactions) + }) + + t.Run("return an error if the input parameters are invalid", func(t *testing.T) { + transactionsToInsert := []Transaction{{ExternalID: "external-id-1"}} + insertedTransactions, err := txModel.BulkInsert(ctx, dbConnectionPool, transactionsToInsert) + require.Error(t, err) + assert.EqualError(t, err, "validating transaction for insertion: asset code must have between 1 and 12 characters") + assert.Nil(t, insertedTransactions) + }) + + t.Run("πŸŽ‰ successfully inserts the transactions successfully", func(t *testing.T) { + incomingTx1 := Transaction{ + ExternalID: "external-id-1", + AssetCode: "USDC", + AssetIssuer: keypair.MustRandom().Address(), + Amount: 1, + Destination: keypair.MustRandom().Address(), + } + incomingTx2 := Transaction{ + ExternalID: "external-id-2", + AssetCode: "USDC", + AssetIssuer: keypair.MustRandom().Address(), + Amount: 2, + Destination: keypair.MustRandom().Address(), + } + insertedTransactions, err := txModel.BulkInsert(ctx, dbConnectionPool, []Transaction{incomingTx1, incomingTx2}) + require.NoError(t, err) + assert.NotNil(t, insertedTransactions) + assert.Len(t, insertedTransactions, 2) + + var insertedTx1, insertedTx2 Transaction + for _, tx := range insertedTransactions { + if tx.ExternalID == incomingTx1.ExternalID { + insertedTx1 = tx + } else if tx.ExternalID == incomingTx2.ExternalID { + insertedTx2 = tx + } else { + require.FailNow(t, "unexpected transaction: %v", tx) + } + } + + assert.Equal(t, incomingTx1.ExternalID, insertedTx1.ExternalID) + assert.Equal(t, incomingTx1.AssetCode, insertedTx1.AssetCode) + assert.Equal(t, incomingTx1.AssetIssuer, insertedTx1.AssetIssuer) + assert.Equal(t, incomingTx1.Amount, insertedTx1.Amount) + assert.Equal(t, incomingTx1.Destination, insertedTx1.Destination) + assert.Equal(t, TransactionStatusPending, insertedTx1.Status) + + assert.Equal(t, incomingTx2.ExternalID, insertedTx2.ExternalID) + assert.Equal(t, incomingTx2.AssetCode, insertedTx2.AssetCode) + assert.Equal(t, incomingTx2.AssetIssuer, insertedTx2.AssetIssuer) + assert.Equal(t, incomingTx2.Amount, insertedTx2.Amount) + assert.Equal(t, incomingTx2.Destination, insertedTx2.Destination) + assert.Equal(t, TransactionStatusPending, insertedTx2.Status) + }) +} + +func Test_TransactionModel_UpdateStatusToSuccess(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + + testCases := []struct { + name string + transactionStatus TransactionStatus + wantErrContains string + }{ + { + name: "cannot transition PENDING->SUCCESS", + transactionStatus: TransactionStatusPending, + wantErrContains: "attempting to transition transaction status to TransactionStatusSuccess: cannot transition from PENDING to SUCCESS", + }, + { + name: "πŸŽ‰ successfully transition PROCESSING->SUCCESS", + transactionStatus: TransactionStatusProcessing, + }, + { + name: "cannot transition SUCCESS->SUCCESS", + transactionStatus: TransactionStatusSuccess, + wantErrContains: "attempting to transition transaction status to TransactionStatusSuccess: cannot transition from SUCCESS to SUCCESS", + }, + { + name: "cannot transition ERROR->SUCCESS", + transactionStatus: TransactionStatusError, + wantErrContains: "attempting to transition transaction status to TransactionStatusSuccess: cannot transition from ERROR to SUCCESS", + }, + } + + unphazedTx := CreateTransactionFixture( + t, + ctx, + dbConnectionPool, + uuid.NewString(), + "USDC", + "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + TransactionStatusPending, + 1.23, + ) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := CreateTransactionFixture( + t, + ctx, + dbConnectionPool, + uuid.NewString(), + "USDC", + "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + tc.transactionStatus, + 1.23, + ) + if (tc.transactionStatus != TransactionStatusSuccess) && (tc.transactionStatus != TransactionStatusError) { + assert.Empty(t, tx.CompletedAt) + } else { + assert.NotEmpty(t, tx.CompletedAt) + } + + updatedTx, err := txModel.UpdateStatusToSuccess(ctx, *tx) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } else { + require.NoError(t, err) + assert.Equal(t, TransactionStatusSuccess, updatedTx.Status) + assert.NotEmpty(t, updatedTx.CompletedAt) + + // verify that the only fields that changed are updated_at, completed_at, status and status_history: + tx.UpdatedAt = updatedTx.UpdatedAt + tx.CompletedAt = updatedTx.CompletedAt + tx.Status = updatedTx.Status + tx.StatusHistory = append(TransactionStatusHistory{}, updatedTx.StatusHistory...) + assert.Equal(t, tx, updatedTx) + } + + // verify the unphazed transaction was not updated + refreshedUnphazedTx, err := txModel.Get(ctx, unphazedTx.ID) + require.NoError(t, err) + assert.Equal(t, unphazedTx, refreshedUnphazedTx) + }) + } +} + +func Test_TransactionModel_UpdateStatusToError(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + + testCases := []struct { + name string + transactionStatus TransactionStatus + wantErrContains string + }{ + { + name: "cannot transition PENDING->ERROR", + transactionStatus: TransactionStatusPending, + wantErrContains: "attempting to transition transaction status to TransactionStatusError: cannot transition from PENDING to ERROR", + }, + { + name: "πŸŽ‰ successfully transition PROCESSING->ERROR", + transactionStatus: TransactionStatusProcessing, + }, + { + name: "cannot transition SUCCESS->ERROR", + transactionStatus: TransactionStatusSuccess, + wantErrContains: "attempting to transition transaction status to TransactionStatusError: cannot transition from SUCCESS to ERROR", + }, + { + name: "cannot transition ERROR->ERROR", + transactionStatus: TransactionStatusError, + wantErrContains: "attempting to transition transaction status to TransactionStatusError: cannot transition from ERROR to ERROR", + }, + } + + unphazedTx := CreateTransactionFixture( + t, + ctx, + dbConnectionPool, + uuid.NewString(), + "USDC", + "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + TransactionStatusPending, + 1.23, + ) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := CreateTransactionFixture( + t, + ctx, + dbConnectionPool, + uuid.NewString(), + "USDC", + "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + tc.transactionStatus, + 1.23, + ) + assert.Empty(t, tx.StatusMessage) + if (tc.transactionStatus != TransactionStatusSuccess) && (tc.transactionStatus != TransactionStatusError) { + assert.Empty(t, tx.CompletedAt) + } else { + assert.NotEmpty(t, tx.CompletedAt) + } + + const someErrMessage = "some error message" + updatedTx, err := txModel.UpdateStatusToError(ctx, *tx, someErrMessage) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } else { + require.NoError(t, err) + assert.Equal(t, TransactionStatusError, updatedTx.Status) + assert.NotEmpty(t, updatedTx.CompletedAt) + + // verify that the only fields that changed are updated_at, completed_at, status, status_message and status history: + tx.UpdatedAt = updatedTx.UpdatedAt + tx.CompletedAt = updatedTx.CompletedAt + tx.Status = updatedTx.Status + tx.StatusMessage = sql.NullString{String: someErrMessage, Valid: true} + tx.StatusHistory = append(TransactionStatusHistory{}, updatedTx.StatusHistory...) + assert.Equal(t, tx, updatedTx) + } + + // verify the unphazed transaction was not updated + refreshedUnphazedTx, err := txModel.Get(ctx, unphazedTx.ID) + require.NoError(t, err) + assert.Equal(t, unphazedTx, refreshedUnphazedTx) + }) + } +} + +func Test_TransactionModel_UpdateStellarTransactionHashAndXDRSent(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + + const txHash = "3389e9f0f1a65f19736cacf544c2e825313e8447f569233bb8db39aa607c8889" + const envelopeXDR = "AAAAAGL8HQvQkbK2HA3WVjRrKmjX00fG8sLI7m0ERwJW/AX3AAAACgAAAAAAAAABAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAArqN6LeOagjxMaUP96Bzfs9e0corNZXzBWJkFoK7kvkwAAAAAO5rKAAAAAAAAAAABVvwF9wAAAEAKZ7IPj/46PuWU6ZOtyMosctNAkXRNX9WCAI5RnfRk+AyxDLoDZP/9l3NvsxQtWj9juQOuoBlFLnWu8intgxQA" + const resultXDR = "AAAAAAAAAGQAAAAAAAAAAQAAAAAAAAAOAAAAAAAAAABw2JZZYIt4n/WXKcnDow3mbTBMPrOnldetgvGUlpTSEQAAAAA=" + + testCases := []struct { + name string + transaction Transaction + txHash string + xdrSent string + wantErrContains string + }{ + { + name: "returns an error if the size of the txHash if invalid", + txHash: "invalid-tx-hash", + wantErrContains: `invalid transaction hash "invalid-tx-hash"`, + }, + { + name: "returns an error if XDR is empty", + txHash: txHash, + wantErrContains: "invalid XDR envelope: decoding EnvelopeType: decoding EnvelopeType: xdr:DecodeInt: EOF while decoding 4 bytes - read: '[]'", + }, + { + name: "returns an error if XDR is not a valid base64 encoded", + txHash: txHash, + xdrSent: "not-base-64-encoded", + wantErrContains: "invalid XDR envelope: decoding EnvelopeType: decoding EnvelopeType: xdr:DecodeInt: illegal base64 data at input byte", + }, + { + name: "returns an error if XDR is not a transaction envelope", + txHash: txHash, + xdrSent: resultXDR, + wantErrContains: "invalid XDR envelope: decoding TransactionV0Envelope: decoding TransactionV0: decoding TimeBounds", + }, + { + name: "πŸŽ‰ successfully validate both the tx hash and the XDR envelope, and save them to the DB", + txHash: txHash, + xdrSent: envelopeXDR, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // create a new transaction + tx, err := txModel.Insert(ctx, Transaction{ + ExternalID: uuid.NewString(), + AssetCode: "USDC", + AssetIssuer: "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", + Amount: 1, + Destination: "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + }) + require.NoError(t, err) + require.NotNil(t, tx) + + // verify the transaction was created + originalTx, err := txModel.Get(ctx, tx.ID) + require.NoError(t, err) + + assert.False(t, originalTx.XDRSent.Valid) + assert.Equal(t, "", originalTx.XDRSent.String) + assert.False(t, originalTx.StellarTransactionHash.Valid) + assert.Equal(t, "", originalTx.StellarTransactionHash.String) + assert.Nil(t, originalTx.SentAt) + assert.Len(t, originalTx.StatusHistory, 1) + initialStatusHistory := originalTx.StatusHistory[0] + + updatedTx, err := txModel.UpdateStellarTransactionHashAndXDRSent(ctx, tx.ID, tc.txHash, tc.xdrSent) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Nil(t, updatedTx) + } else { + // check if object has been updated correctly + require.NoError(t, err) + assert.True(t, updatedTx.XDRSent.Valid) + assert.Equal(t, envelopeXDR, updatedTx.XDRSent.String) + assert.True(t, updatedTx.StellarTransactionHash.Valid) + assert.Equal(t, txHash, updatedTx.StellarTransactionHash.String) + assert.NotNil(t, updatedTx.SentAt) + assert.Equal(t, originalTx.AttemptsCount+1, updatedTx.AttemptsCount) + + // assert new status history info: + assert.Len(t, updatedTx.StatusHistory, 2) + newStatusHist := updatedTx.StatusHistory[1] + assert.Equal(t, string(updatedTx.Status), newStatusHist.Status) + assert.Equal(t, updatedTx.StellarTransactionHash.String, newStatusHist.StellarTransactionHash) + assert.Equal(t, updatedTx.XDRSent.String, newStatusHist.XDRSent) + assert.Empty(t, updatedTx.XDRReceived) + wantStatusHistory := TransactionStatusHistory{initialStatusHistory, newStatusHist} + + // retrieve the transaction from the database and check if values are updated + refreshedTx, err := txModel.Get(ctx, tx.ID) + require.NoError(t, err) + assert.Equal(t, updatedTx, refreshedTx) + + // make sure only the expected fields were updated: + originalTx.XDRSent = refreshedTx.XDRSent + originalTx.StellarTransactionHash = refreshedTx.StellarTransactionHash + originalTx.SentAt = refreshedTx.SentAt + originalTx.UpdatedAt = refreshedTx.UpdatedAt + originalTx.StatusHistory = wantStatusHistory + originalTx.AttemptsCount += 1 + assert.Equal(t, refreshedTx, originalTx) + } + }) + } +} + +func Test_TransactionModel_UpdateStellarTransactionXDRReceived(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + + const envelopeXDR = "AAAAAGL8HQvQkbK2HA3WVjRrKmjX00fG8sLI7m0ERwJW/AX3AAAACgAAAAAAAAABAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAArqN6LeOagjxMaUP96Bzfs9e0corNZXzBWJkFoK7kvkwAAAAAO5rKAAAAAAAAAAABVvwF9wAAAEAKZ7IPj/46PuWU6ZOtyMosctNAkXRNX9WCAI5RnfRk+AyxDLoDZP/9l3NvsxQtWj9juQOuoBlFLnWu8intgxQA" + const resultXDR = "AAAAAAAAAGQAAAAAAAAAAQAAAAAAAAAOAAAAAAAAAABw2JZZYIt4n/WXKcnDow3mbTBMPrOnldetgvGUlpTSEQAAAAA=" + + testCases := []struct { + name string + transaction Transaction + xdrReceived string + wantErrContains string + }{ + { + name: "returns an error if XDR is empty", + xdrReceived: "", + wantErrContains: "invalid XDR result: decoding Int64: decoding Hyper: xdr:DecodeHyper: EOF while decoding 8 bytes - read: '[]'", + }, + { + name: "returns an error if XDR is not a valid base64 encoded", + xdrReceived: "not-base-64-encoded", + wantErrContains: "invalid XDR result: decoding Int64: decoding Hyper: xdr:DecodeHyper: illegal base64 data", + }, + { + name: "returns an error if XDR is not a transaction envelope", + xdrReceived: envelopeXDR, + wantErrContains: "invalid XDR result: decoding TransactionResultResult: decoding TransactionResultCode: '-795757898' is not a valid TransactionResultCode enum value", + }, + { + name: "πŸŽ‰ successfully validate a transaction result and save it in the DB", + xdrReceived: resultXDR, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // create a new transaction + tx, err := txModel.Insert(ctx, Transaction{ + ExternalID: uuid.NewString(), + AssetCode: "USDC", + AssetIssuer: "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", + Amount: 1, + Destination: "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + }) + require.NoError(t, err) + require.NotNil(t, tx) + + assert.Equal(t, false, tx.XDRReceived.Valid) + assert.Equal(t, "", tx.XDRReceived.String) + + updatedTx, err := txModel.UpdateStellarTransactionXDRReceived(ctx, tx.ID, tc.xdrReceived) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } else { + // check if object has been updated correctly + require.NoError(t, err) + assert.Equal(t, true, updatedTx.XDRReceived.Valid) + assert.Equal(t, resultXDR, updatedTx.XDRReceived.String) + + // retrieve the transaction from the database and check if values are updated + refreshedTx, err := txModel.Get(ctx, tx.ID) + require.NoError(t, err) + assert.Equal(t, refreshedTx, updatedTx) + } + }) + } +} + +func Test_Transaction_validate(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + require.NoError(t, err) + + testCases := []struct { + name string + transaction Transaction + wantErrContains string + }{ + { + name: "validate ExternalID", + transaction: Transaction{}, + wantErrContains: "external ID is required", + }, + { + name: "validate AssetCode (min size)", + transaction: Transaction{ + ExternalID: "123", + }, + wantErrContains: "asset code must have between 1 and 12 characters", + }, + { + name: "validate AssetCode (max size)", + transaction: Transaction{ + ExternalID: "123", + AssetCode: "1234567890123", + }, + wantErrContains: "asset code must have between 1 and 12 characters", + }, + { + name: "validate AssetIssuer (cannot be nil)", + transaction: Transaction{ + ExternalID: "123", + AssetCode: "USDC", + }, + wantErrContains: "asset issuer is required", + }, + { + name: "validate AssetIssuer (not a valid public key)", + transaction: Transaction{ + ExternalID: "123", + AssetCode: "USDC", + AssetIssuer: "invalid-issuer", + }, + wantErrContains: `asset issuer "invalid-issuer" is not a valid ed25519 public key`, + }, + { + name: "validate Amount", + transaction: Transaction{ + ExternalID: "123", + AssetCode: "USDC", + AssetIssuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + }, + wantErrContains: "amount must be positive", + }, + { + name: "validate Destination", + transaction: Transaction{ + ExternalID: "123", + AssetCode: "USDC", + AssetIssuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + Amount: 100.0, + Destination: "invalid-destination", + }, + wantErrContains: `destination "invalid-destination" is not a valid ed25519 public key`, + }, + { + name: "πŸŽ‰ successfully validate USDC transaction", + transaction: Transaction{ + ExternalID: "123", + AssetCode: "USDC", + AssetIssuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + Amount: 100.0, + Destination: "GDUCE34WW5Z34GMCEPURYANUCUP47J6NORJLKC6GJNMDLN4ZI4PMI2MG", + }, + }, + { + name: "πŸŽ‰ successfully validate XLM transaction", + transaction: Transaction{ + ExternalID: "123", + AssetCode: "xLm", + Amount: 100.0, + Destination: "GDUCE34WW5Z34GMCEPURYANUCUP47J6NORJLKC6GJNMDLN4ZI4PMI2MG", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.transaction.validate() + if tc.wantErrContains == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } + }) + } +} + +func Test_TransactionModel_GetTransactionBatchForUpdate(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + + testCase := []struct { + name string + transactionStatus TransactionStatus + shouldBeFound bool + batchSize int + wantErrContains string + }{ + { + name: "batchSize must be >= 0", + transactionStatus: TransactionStatusSuccess, + batchSize: 0, + wantErrContains: "batch size must be greater than 0", + shouldBeFound: false, + }, + { + name: "no transactions found (empty database)", + transactionStatus: "", + batchSize: 100, + shouldBeFound: false, + }, + { + name: "no transactions found (PENDING)", + transactionStatus: TransactionStatusPending, + batchSize: 100, + shouldBeFound: false, + }, + { + name: "no transactions found (PROCESSING)", + transactionStatus: TransactionStatusProcessing, + batchSize: 100, + shouldBeFound: false, + }, + { + name: "πŸŽ‰ transactions successfully found (SUCCESS)", + transactionStatus: TransactionStatusSuccess, + batchSize: 100, + shouldBeFound: true, + }, + { + name: "πŸŽ‰ transactions successfully found (ERROR)", + transactionStatus: TransactionStatusError, + batchSize: 100, + shouldBeFound: true, + }, + } + + const txCount = 3 + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + defer func() { + err = dbTx.Rollback() + require.NoError(t, err) + }() + + var transactions []*Transaction + if tc.transactionStatus != "" { + // create transactions and get their IDs + transactions = CreateTransactionFixtures( + t, + ctx, + dbTx, + txCount, + "USDC", + "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + tc.transactionStatus, + 1.2, + ) + } + var txIDs []string + for _, tx := range transactions { + txIDs = append(txIDs, tx.ID) + } + + foundTransactions, err := txModel.GetTransactionBatchForUpdate(ctx, dbTx, tc.batchSize) + if tc.wantErrContains == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } + + var foundTxIDs []string + for _, tx := range foundTransactions { + foundTxIDs = append(foundTxIDs, tx.ID) + } + + if !tc.shouldBeFound { + assert.Equal(t, 0, len(foundTransactions)) + } else { + assert.Equal(t, txCount, len(foundTransactions)) + assert.ElementsMatch(t, txIDs, foundTxIDs) + } + }) + } + + DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) +} + +func Test_TransactionModel_UpdateSyncedTransactions(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := NewTransactionModel(dbConnectionPool) + + testCase := []struct { + name string + shouldSendEmptyIDs bool + shouldSendInvalidIDs bool + transactionStatus TransactionStatus + wantErrContains string + }{ + { + name: "rerturn an error if txIDs is empty", + shouldSendEmptyIDs: true, + wantErrContains: "no transaction IDs provided", + }, + { + name: "rerturn an error if the IDs sent don't exist", + shouldSendInvalidIDs: true, + wantErrContains: "expected 1 rows to be affected, got 0", + }, + { + name: "rerturn an error if the IDs sent were not ready to be synched (PENDING)", + transactionStatus: TransactionStatusPending, + wantErrContains: "expected 3 rows to be affected, got 0", + }, + { + name: "rerturn an error if the IDs sent were not ready to be synched (PROCESSING)", + transactionStatus: TransactionStatusProcessing, + wantErrContains: "expected 3 rows to be affected, got 0", + }, + { + name: "πŸŽ‰ successfully set the status of transactions to synched (SUCCESS)", + transactionStatus: TransactionStatusSuccess, + }, + { + name: "πŸŽ‰ successfully set the status of transactions to synched (ERROR)", + transactionStatus: TransactionStatusError, + }, + } + + const txCount = 3 + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + defer func() { + err = dbTx.Rollback() + require.NoError(t, err) + }() + + // create transactions and get their IDs + var txIDs []string + if tc.shouldSendEmptyIDs { + txIDs = []string{} + } else if tc.shouldSendInvalidIDs { + txIDs = []string{"invalid-id"} + } else { + transactions := CreateTransactionFixtures( + t, + ctx, + dbTx, + txCount, + "USDC", + "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + "GBHNIYGWZUAVZX7KTLVSMILBXJMUACVO6XBEKIN6RW7AABDFH6S7GK2Y", + tc.transactionStatus, + 1.2, + ) + for _, tx := range transactions { + txIDs = append(txIDs, tx.ID) + } + } + + err = txModel.UpdateSyncedTransactions(ctx, dbTx, txIDs) + + // count the number of transactions that were synched + var count int + countErr := dbTx.GetContext(ctx, &count, "SELECT COUNT(*) FROM submitter_transactions WHERE synced_at IS NOT NULL") + require.NoError(t, countErr) + + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + assert.Equal(t, 0, count) + } else { + require.NoError(t, err) + assert.Equal(t, txCount, count) + } + }) + } + + DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) +} + +func Test_TransactionModel_queryFilterForLockedState(t *testing.T) { + txModel := &TransactionModel{} + + testCases := []struct { + name string + locked bool + ledgerNumber int32 + wantFilter string + }{ + { + name: "locked to ledgerNumber=10", + locked: true, + ledgerNumber: 10, + wantFilter: "(locked_until_ledger_number >= 10)", + }, + { + name: "unlocked or expired on ledgerNumber=20", + locked: false, + ledgerNumber: 20, + wantFilter: "(locked_until_ledger_number IS NULL OR locked_until_ledger_number < 20)", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotFilter := txModel.queryFilterForLockedState(tc.locked, tc.ledgerNumber) + assert.Equal(t, tc.wantFilter, gotFilter) + }) + } +} + +func Test_TransactionModel_Lock(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + transactionModel := TransactionModel{DBConnectionPool: dbConnectionPool} + + const currentLedger int32 = 10 + const nextLedgerLock int32 = 20 + + testCases := []struct { + name string + initialLockedAt sql.NullTime + initialSyncedAt sql.NullTime + initialStatus TransactionStatus + initialLockedUntilLedger sql.NullInt32 + expectedErrContains string + }{ + { + name: "πŸŽ‰ successfully locks transaction without any previous lock (PENDING)", + initialStatus: TransactionStatusPending, + }, + { + name: "πŸŽ‰ successfully locks transaction without any previous lock (PROCESSING)", + initialStatus: TransactionStatusProcessing, + }, + { + name: "πŸŽ‰ successfully locks transaction with lock expired", + initialStatus: TransactionStatusPending, + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger - 1, Valid: true}, + }, + { + name: "🚧 cannot be locked again if still locked", + initialStatus: TransactionStatusPending, + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger, Valid: true}, + expectedErrContains: ErrRecordNotFound.Error(), + }, + { + name: "🚧 cannot be locked if the status is SUCCESS", + initialStatus: TransactionStatusSuccess, + expectedErrContains: ErrRecordNotFound.Error(), + }, + { + name: "🚧 cannot be locked if the status is ERROR", + initialStatus: TransactionStatusError, + expectedErrContains: ErrRecordNotFound.Error(), + }, + { + name: "🚧 cannot be locked if siced_at is not empty", + initialStatus: TransactionStatusPending, + initialSyncedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + expectedErrContains: ErrRecordNotFound.Error(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := CreateTransactionFixture(t, ctx, dbConnectionPool, uuid.NewString(), "USDC", "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", tc.initialStatus, 1) + q := `UPDATE submitter_transactions SET locked_at = $1, locked_until_ledger_number = $2, synced_at = $3, status = $4 WHERE id = $5` + _, err := dbConnectionPool.ExecContext(ctx, q, tc.initialLockedAt, tc.initialLockedUntilLedger, tc.initialSyncedAt, tc.initialStatus, tx.ID) + require.NoError(t, err) + + tx, err = transactionModel.Lock(ctx, dbConnectionPool, tx.ID, currentLedger, nextLedgerLock) + + if tc.expectedErrContains == "" { + require.NoError(t, err) + tx, err = transactionModel.Get(ctx, tx.ID) + require.NoError(t, err) + assert.NotNil(t, tx.LockedAt) + assert.True(t, tx.LockedUntilLedgerNumber.Valid) + assert.Equal(t, nextLedgerLock, tx.LockedUntilLedgerNumber.Int32) + assert.Equal(t, TransactionStatusProcessing, tx.Status) + + var txRefreshed *Transaction + txRefreshed, err = transactionModel.Get(ctx, tx.ID) + require.NoError(t, err) + require.Equal(t, *txRefreshed, *tx) + } else { + require.Error(t, err) + assert.ErrorContains(t, err, tc.expectedErrContains) + } + + DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + }) + } +} + +func Test_TransactionModel_Unlock(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + transactionModel := TransactionModel{DBConnectionPool: dbConnectionPool} + + const currentLedger int32 = 10 + + testCases := []struct { + name string + initialLockedAt sql.NullTime + initialSyncedAt sql.NullTime + initialStatus TransactionStatus + initialLockedUntilLedger sql.NullInt32 + }{ + { + name: "πŸŽ‰ successfully locks transaction without any previous lock", + initialStatus: TransactionStatusPending, + }, + { + name: "πŸŽ‰ successfully locks transaction with lock expired", + initialStatus: TransactionStatusPending, + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger - 1, Valid: true}, + }, + { + name: "πŸŽ‰ successfully unlocks locked transaction", + initialStatus: TransactionStatusPending, + initialLockedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + initialLockedUntilLedger: sql.NullInt32{Int32: currentLedger, Valid: true}, + }, + { + name: "πŸŽ‰ successfully unlocks transaction with status is SUCCESS", + initialStatus: TransactionStatusSuccess, + }, + { + name: "πŸŽ‰ successfully unlocks transaction with status is ERROR", + initialStatus: TransactionStatusError, + }, + { + name: "πŸŽ‰ successfully unlocks transaction with siced_at not empty", + initialStatus: TransactionStatusPending, + initialSyncedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := CreateTransactionFixture(t, ctx, dbConnectionPool, uuid.NewString(), "USDC", "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", tc.initialStatus, 1) + q := `UPDATE submitter_transactions SET locked_at = $1, locked_until_ledger_number = $2, synced_at = $3, status = $4 WHERE id = $5` + _, err := dbConnectionPool.ExecContext(ctx, q, tc.initialLockedAt, tc.initialLockedUntilLedger, tc.initialSyncedAt, tc.initialStatus, tx.ID) + require.NoError(t, err) + + tx, err = transactionModel.Unlock(ctx, dbConnectionPool, tx.ID) + require.NoError(t, err) + + tx, err = transactionModel.Get(ctx, tx.ID) + require.NoError(t, err) + assert.Nil(t, tx.LockedAt) + assert.False(t, tx.LockedUntilLedgerNumber.Valid) + + var txRefreshed *Transaction + txRefreshed, err = transactionModel.Get(ctx, tx.ID) + require.NoError(t, err) + require.Equal(t, *txRefreshed, *tx) + + DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + }) + } +} + +func Test_TransactionModel_PrepareTransactionForReprocessing(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + transactionModel := NewTransactionModel(dbConnectionPool) + + testCases := []struct { + name string + status TransactionStatus + synchedAt sql.NullTime + wantError error + }{ + { + name: "cannot mark for reporcessing if the status is SUCCESS", + status: TransactionStatusSuccess, + wantError: ErrRecordNotFound, + }, + { + name: "cannot mark for reporcessing if the status is ERROR", + status: TransactionStatusError, + wantError: ErrRecordNotFound, + }, + { + name: "cannot mark for reporcessing if synced_at is not empty", + status: TransactionStatusProcessing, + synchedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + wantError: ErrRecordNotFound, + }, + { + name: "πŸŽ‰ successfully mark as processing if tx is PENDING and not synced transaction", + status: TransactionStatusPending, + }, + { + name: "πŸŽ‰ successfully mark as processing if tx is PROCESSING and not synced transaction", + status: TransactionStatusProcessing, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + const lockedUntilLedger = 2 + + // create and prepare the transaction: + tx := CreateTransactionFixture(t, ctx, dbConnectionPool, uuid.NewString(), "USDC", "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", tc.status, 1) + q := `UPDATE submitter_transactions SET status = $1, synced_at = $2, locked_at = NOW(), locked_until_ledger_number=$3 WHERE id = $4` + _, err = dbConnectionPool.ExecContext(ctx, q, tc.status, tc.synchedAt, lockedUntilLedger, tx.ID) + require.NoError(t, err) + + // mark the transaction for reprocessing: + updatedTx, err := transactionModel.PrepareTransactionForReprocessing(ctx, dbConnectionPool, tx.ID) + + // check the result: + if tc.wantError != nil { + require.Error(t, err) + assert.Equal(t, tc.wantError, err) + assert.Nil(t, updatedTx) + } else { + require.NoError(t, err) + assert.Equal(t, tc.status, updatedTx.Status) + assert.Nil(t, updatedTx.SyncedAt) + + // Check if only the expected fields were updated: + assert.Nil(t, updatedTx.LockedAt) + assert.False(t, updatedTx.LockedUntilLedgerNumber.Valid) + assert.False(t, updatedTx.StellarTransactionHash.Valid) + assert.False(t, updatedTx.XDRSent.Valid) + assert.False(t, updatedTx.XDRReceived.Valid) + + // Check if the returned transaction is exactly the same as a fresh one from the DB: + refreshedTx, err := transactionModel.Get(ctx, tx.ID) + require.NoError(t, err) + require.Equal(t, refreshedTx, updatedTx) + } + }) + } +} diff --git a/internal/transactionsubmission/transaction_worker.go b/internal/transactionsubmission/transaction_worker.go new file mode 100644 index 000000000..9201165fa --- /dev/null +++ b/internal/transactionsubmission/transaction_worker.go @@ -0,0 +1,544 @@ +package transactionsubmission + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/strkey" + "github.com/stellar/go/support/log" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + tssUtils "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "golang.org/x/exp/slices" +) + +// Review these TODOs originally created by Stephen: +// TODO - memo/memoType not supported yet - [SDP-463] +// TODO - re-enable metrics/observer – [SDP-772] + +type TxJob store.ChannelTransactionBundle + +func (job TxJob) String() string { + return fmt.Sprintf("TxJob{ChannelAccount: %q, Transaction: %q, LockedUntilLedgerNumber: \"%d\"}", job.ChannelAccount.PublicKey, job.Transaction.ID, job.LockedUntilLedgerNumber) +} + +type TransactionWorker struct { + dbConnectionPool db.DBConnectionPool + txModel store.TransactionStore + chAccModel store.ChannelAccountStore + engine *engine.SubmitterEngine + sigService engine.SignatureService + maxBaseFee int + crashTrackerClient crashtracker.CrashTrackerClient + txProcessingLimiter *engine.TransactionProcessingLimiter +} + +func NewTransactionWorker( + dbConnectionPool db.DBConnectionPool, + txModel *store.TransactionModel, + chAccModel *store.ChannelAccountModel, + engine *engine.SubmitterEngine, + sigService engine.SignatureService, + maxBaseFee int, + crashTrackerClient crashtracker.CrashTrackerClient, + txProcessingLimiter *engine.TransactionProcessingLimiter, +) (TransactionWorker, error) { + if dbConnectionPool == nil { + return TransactionWorker{}, fmt.Errorf("dbConnectionPool cannot be nil") + } + + if txModel == nil { + return TransactionWorker{}, fmt.Errorf("txModel cannot be nil") + } + + if chAccModel == nil { + return TransactionWorker{}, fmt.Errorf("chAccModel cannot be nil") + } + + if tssUtils.IsEmpty(engine) { + return TransactionWorker{}, fmt.Errorf("engine cannot be nil") + } + + if tssUtils.IsEmpty(sigService) { + return TransactionWorker{}, fmt.Errorf("sigService cannot be nil") + } + + if maxBaseFee < txnbuild.MinBaseFee { + return TransactionWorker{}, fmt.Errorf("maxBaseFee must be greater than or equal to %d", txnbuild.MinBaseFee) + } + + if crashTrackerClient == nil { + return TransactionWorker{}, fmt.Errorf("crashTrackerClient cannot be nil") + } + + if txProcessingLimiter == nil { + return TransactionWorker{}, fmt.Errorf("txProcessingLimiter cannot be nil") + } + + return TransactionWorker{ + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: engine, + sigService: sigService, + maxBaseFee: maxBaseFee, + crashTrackerClient: crashTrackerClient, + txProcessingLimiter: txProcessingLimiter, + }, nil +} + +func (tw *TransactionWorker) Run(ctx context.Context, txJob *TxJob) { + err := tw.runJob(ctx, txJob) + if err != nil { + log.Ctx(ctx).Errorf("Handle unexpected error: %v", err) + } +} + +// TODO: add unit tests and godoc to this function +func (tw *TransactionWorker) runJob(ctx context.Context, txJob *TxJob) error { + err := tw.validateJob(txJob) + if err != nil { + return fmt.Errorf("validating job: %w", err) + } + + if txJob == nil { + return fmt.Errorf("received nil transaction job") + } else if txJob.Transaction.StellarTransactionHash.Valid { + return tw.reconcileSubmittedTransaction(ctx, txJob) + } else { + return tw.processTransactionSubmission(ctx, txJob) + } +} + +// TODO: add tests +// handleFailedTransaction will wrap up the job when the transaction was submitted to the network but failed. +// This method will only return an error if something goes wromg when handling the result and marking the transaction as ERROR. +func (tw *TransactionWorker) handleFailedTransaction(ctx context.Context, txJob *TxJob, hTxResp horizon.Transaction, hErr error) error { + log.Ctx(ctx).Errorf("πŸ”΄ Error processing job: %v", hErr) + + err := tw.saveResponseXDRIfPresent(ctx, txJob, hTxResp) + if err != nil { + return fmt.Errorf("saving response XDR: %w", err) + } + + var shouldMarkAsError bool + var hErrWrapper *utils.HorizonErrorWrapper + if errors.As(hErr, &hErrWrapper) { + tw.txProcessingLimiter.AdjustLimitIfNeeded(hErrWrapper) + + if hErrWrapper.ResultCodes != nil { + // TODO: move this logic inside the HorizonErrorWrapper + // ref: https://developers.stellar.org/api/horizon/errors/result-codes/ + failedTxErrCodes := []string{ + "tx_bad_auth", + "tx_bad_auth_extra", + "tx_insufficient_balance", + } + if slices.Contains(failedTxErrCodes, hErrWrapper.ResultCodes.TransactionCode) || slices.Contains(failedTxErrCodes, hErrWrapper.ResultCodes.InnerTransactionCode) { + shouldMarkAsError = true + } + + // TODO: move this logic inside the HorizonErrorWrapper + // ref: https://developers.stellar.org/api/horizon/errors/result-codes/ + failedOpCodes := []string{ + "op_bad_auth", + "op_underfunded", + "op_src_not_authorized", + "op_no_destination", + "op_no_trust", + "op_line_full", + "op_not_authorized", + "op_no_issuer", + } + if !shouldMarkAsError { + for _, opResult := range hErrWrapper.ResultCodes.OperationCodes { + if slices.Contains(failedOpCodes, opResult) { + shouldMarkAsError = true + break + } + } + } + + if shouldMarkAsError { + var updatedTx *store.Transaction + updatedTx, err = tw.txModel.UpdateStatusToError(ctx, txJob.Transaction, hErrWrapper.Error()) + if err != nil { + return fmt.Errorf("updating transaction status to error: %w", err) + } + + txJob.Transaction = *updatedTx + } + } + } + + // TODO: call MonitorService if needed + // TODO: call crashTrackerClient if needed + // TODO: op_bad_auth, tx_bad_auth, tx_bad_auth_extra are big problems that need to be reported accordingly + // TODO: tx_bad_seq is a big problem that needs to be reported accordingly + + // {Old TSS approach} -> {new approach}: + // - `504`: {retry in memory} -> {marked for retry} (pause/jitter could come later) + // - `429`: {paused and marked for retry} -> {marked for retry} (pause/jitter could come later) + // - `400 - tx_insufficient_fee` {marked for retry with exponential jitter until max_retry is reached} -> {marked for retry forever} (pause/jitter could come later) + // - `400 - tx_bad_seq` {marked as failed} -> {marked for retry and reported to crash tracker and observer} + // - `400 - tx_too_late` (bounds expired) {marked as failed} -> {marked for retry and reported to crash tracker and observer} + // - `400 - ???`: {marked as failed} -> {marked for retry and reported to crash tracker and observer} + // - unsupported error: {marked as failed} -> {marked for retry and reported to crash tracker and observer} + + // Some ideas for error handling (ref: https://developers.stellar.org/api/horizon/errors/result-codes/): + // BadAuthentication(): + // op_bad_auth (in result_codes.operations) + // tx_bad_auth (in result_codes.(inner_)transaction) + // tx_bad_auth_extra (in result_codes.(inner_)transaction) + // + // NotEnoughLumens(): + // op_underfunded (in result_codes.operations) + // tx_insufficient_balance (in result_codes.(inner_)transaction) + // + // SendingAccountIsBlocked() + // op_src_not_authorized (in result_codes.operations) + // + // DestinationAccountNotFound(): + // op_no_destination (in result_codes.operations) + // + // DesinationIsMissingTrustlineOrLimit(): + // op_no_trust (in result_codes.operations) + // op_line_full (in result_codes.operations) + // + // DestinationAccountIsBlocked(): + // op_not_authorized (in result_codes.operations) + // + // NonExistentAsset(): + // op_no_issuer (in result_codes.operations) + + err = tw.unlockJob(ctx, txJob) + if err != nil { + return fmt.Errorf("unlocking job: %w", err) + } + + return nil +} + +// TODO: add tests +// unlockJob will unlock the channel account and transaction instantaneously, so they can be made available ASAP. If +// this method is not called, the algorithm will fall back to get these resources qutomatically unlocked when their +// `locked-to-ledger` expire. +func (tw *TransactionWorker) unlockJob(ctx context.Context, txJob *TxJob) error { + _, err := tw.chAccModel.Unlock(ctx, tw.dbConnectionPool, txJob.ChannelAccount.PublicKey) + if err != nil { + return fmt.Errorf("unlocking channel account: %w", err) + } + + _, err = tw.txModel.Unlock(ctx, tw.dbConnectionPool, txJob.Transaction.ID) + if err != nil { + return fmt.Errorf("unlocking transaction: %w", err) + } + + return nil +} + +// handleSuccessfulTransaction will wrap up the job when the transaction has been successfully submitted to the network. +// This method will only return an error if something goes wromg when handling the result and marking the transaction as SUCCESS. +func (tw *TransactionWorker) handleSuccessfulTransaction(ctx context.Context, txJob *TxJob, hTxResp horizon.Transaction) error { + err := tw.saveResponseXDRIfPresent(ctx, txJob, hTxResp) + if err != nil { + return fmt.Errorf("saving response XDR: %w", err) + } + if !hTxResp.Successful { + return fmt.Errorf("transaction was not successful for some reason") + } + + _, err = tw.txModel.UpdateStatusToSuccess(ctx, txJob.Transaction) + if err != nil { + return utils.NewTransactionStatusUpdateError("SUCCESS", txJob.Transaction.ID, false, err) + } + + err = tw.unlockJob(ctx, txJob) + if err != nil { + return fmt.Errorf("unlocking job: %w", err) + } + + log.Ctx(ctx).Infof("πŸŽ‰ Successfully processed transaction job %v", txJob) + return nil +} + +// reconcileSubmittedTransaction will check the status of a previously submitted transaction and handle it accordingly. +// If the transaction was successful, it will be marked as such and the job will be unlocked. +// If the transaction failed, it will be marked for resubmission. +func (tw *TransactionWorker) reconcileSubmittedTransaction(ctx context.Context, txJob *TxJob) error { + log.Ctx(ctx).Infof("πŸ” Reconciling previously submitted transaction %v...", txJob) + + err := tw.validateJob(txJob) + if err != nil { + return fmt.Errorf("validating bundle: %w", err) + } + + txHash := txJob.Transaction.StellarTransactionHash.String + txDetail, err := tw.engine.HorizonClient.TransactionDetail(txHash) + hWrapperErr := utils.NewHorizonErrorWrapper(err) + if err == nil && txDetail.Successful { + err = tw.handleSuccessfulTransaction(ctx, txJob, txDetail) + if err != nil { + return fmt.Errorf("handling successful transaction: %w", err) + } + return nil + } else if (err == nil && !txDetail.Successful) || hWrapperErr.IsNotFound() { + // Unsuccesful hash: 98d3549076b119dbda42c17c2310d04666ef35524397ad3decb773ef1cebab1e + // Nonexistent hash: 3389e9f0f1a65f19736cacf544c2e825313e8447f569233bb8db39aa607c8889 + log.Ctx(ctx).Warnf("Previous transaction didn't make through, marking %v for resubmission...", txJob) + + _, err = tw.txModel.PrepareTransactionForReprocessing(ctx, tw.dbConnectionPool, txJob.Transaction.ID) + if err != nil { + return fmt.Errorf("pushing back transaction to queue: %w", err) + } + + err = tw.unlockJob(ctx, txJob) + if err != nil { + return fmt.Errorf("unlocking job: %w", err) + } + } else { + // Invalid hash: 123 + log.Ctx(ctx).Warnf("received unexpected horizon error: %v", hWrapperErr) + return fmt.Errorf("unexpected error: %w", hWrapperErr) + } + + return nil +} + +func (tw *TransactionWorker) processTransactionSubmission(ctx context.Context, txJob *TxJob) error { + log.Ctx(ctx).Infof("🚧 Processing transaction submission for job %v...", txJob) + + // STEP 1: validate bundle + err := tw.validateJob(txJob) + if err != nil { + return fmt.Errorf("validating bundle: %w", err) + } + + // STEP 2: prepare transaction for processing + feeBumpTx, err := tw.prepareForSubmission(ctx, txJob) + if err != nil { + return fmt.Errorf("preparing bundle for processing: %w", err) + } + + // STEP 3: process transaction + err = tw.submit(ctx, txJob, feeBumpTx) + if err != nil { + return fmt.Errorf("processing bundle: %w", err) + } + + return nil +} + +// validateJob will check if the job is valid for processing or reconciliation. +func (tw *TransactionWorker) validateJob(txJob *TxJob) error { + allowedStatuses := []store.TransactionStatus{store.TransactionStatusPending, store.TransactionStatusProcessing} + if !slices.Contains(allowedStatuses, txJob.Transaction.Status) { + return fmt.Errorf("invalid transaction status: %v", txJob.Transaction.Status) + } + + // TODO: make sure we're handling 429s upstream + currentLedgerNumber, err := tw.engine.LedgerNumberTracker.GetLedgerNumber() + if err != nil { + return fmt.Errorf("getting current ledger number: %w", err) + } + + if !txJob.Transaction.IsLocked(int32(currentLedgerNumber)) { + return fmt.Errorf("transaction should be locked") + } + + if !txJob.ChannelAccount.IsLocked(int32(currentLedgerNumber)) { + return fmt.Errorf("channel account should be locked") + } + + return nil +} + +func (tw *TransactionWorker) prepareForSubmission(ctx context.Context, txJob *TxJob) (*txnbuild.FeeBumpTransaction, error) { + feeBumpTx, err := tw.buildAndSignTransaction(ctx, txJob) + if err != nil { + return nil, fmt.Errorf("building transaction: %w", err) + } + + // Important: We need to save tx hash before submitting a transaction. + // If the script/server crashes after transaction is submitted but before the response + // is processed, we can easily determine whether tx was sent or not later using tx hash. + feeBumpTxHash, err := feeBumpTx.HashHex(tw.sigService.NetworkPassphrase()) + if err != nil { + return nil, fmt.Errorf("hashing transaction for job %v: %w", txJob, err) + } + + sentXDR, err := feeBumpTx.Base64() + if err != nil { + return nil, fmt.Errorf("getting envelopeXDR for job %v: %w", txJob, err) + } + + updatedTx, err := tw.txModel.UpdateStellarTransactionHashAndXDRSent(ctx, txJob.Transaction.ID, feeBumpTxHash, sentXDR) + if err != nil { + return nil, fmt.Errorf("saving transaction metadata for job %v: %w", txJob, err) + } + txJob.Transaction = *updatedTx + + return feeBumpTx, nil +} + +// buildAndSignTransaction builds & signs a Stellar payment transaction that is wrapped in a feebump transaction. +func (tw *TransactionWorker) buildAndSignTransaction(ctx context.Context, txJob *TxJob) (feeBumpTx *txnbuild.FeeBumpTransaction, err error) { + // validate the transaction asset + if txJob.Transaction.AssetCode == "" { + return nil, fmt.Errorf("asset code cannot be empty") + } + var asset txnbuild.Asset = txnbuild.NativeAsset{} + if strings.ToUpper(txJob.Transaction.AssetCode) != "XLM" { + if !strkey.IsValidEd25519PublicKey(txJob.Transaction.AssetIssuer) { + return nil, fmt.Errorf("invalid asset issuer: %v", txJob.Transaction.AssetIssuer) + } + asset = txnbuild.CreditAsset{ + Code: txJob.Transaction.AssetCode, + Issuer: txJob.Transaction.AssetIssuer, + } + } + + horizonAccount, err := tw.engine.HorizonClient.AccountDetail(horizonclient.AccountRequest{AccountID: txJob.ChannelAccount.PublicKey}) + if err != nil { + return nil, utils.NewHorizonErrorWrapper(err) + } + + // build the inner payment transaction + paymentTx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: txJob.ChannelAccount.PublicKey, + Sequence: horizonAccount.Sequence, + }, + Operations: []txnbuild.Operation{ + &txnbuild.Payment{ + SourceAccount: tw.sigService.DistributionAccount(), + Amount: strconv.FormatFloat(txJob.Transaction.Amount, 'f', 6, 32), // TODO find a better way to do this + Destination: txJob.Transaction.Destination, + Asset: asset, + }, + }, + BaseFee: int64(tw.maxBaseFee), + Preconditions: txnbuild.Preconditions{ + TimeBounds: txnbuild.NewTimeout(300), // maximum 5 minutes + LedgerBounds: &txnbuild.LedgerBounds{MaxLedger: uint32(txJob.LockedUntilLedgerNumber)}, // currently, 8-10 ledgers in the future + }, + IncrementSequenceNum: true, + }, + ) + if err != nil { + return nil, fmt.Errorf("building transaction for job %v: %w", txJob, err) + } + + paymentTx, err = tw.sigService.SignStellarTransaction(ctx, paymentTx, tw.sigService.DistributionAccount(), txJob.ChannelAccount.PublicKey) + if err != nil { + return nil, fmt.Errorf("signing transaction: for job %v: %w", txJob, err) + } + + // build the outer fee-bump transaction + feeBumpTx, err = txnbuild.NewFeeBumpTransaction( + txnbuild.FeeBumpTransactionParams{ + Inner: paymentTx, + FeeAccount: tw.sigService.DistributionAccount(), + BaseFee: int64(tw.maxBaseFee), + }, + ) + if err != nil { + return nil, fmt.Errorf("building fee-bump transaction for job %v: %w", txJob, err) + } + + // generate a random number to use as the fee-bump transaction's sequence number + feeBumpTx, err = tw.sigService.SignFeeBumpStellarTransaction(ctx, feeBumpTx, tw.sigService.DistributionAccount()) + if err != nil { + return nil, fmt.Errorf("signing fee-bump transaction for job %v: %w", txJob, err) + } + + return feeBumpTx, nil +} + +func (tw *TransactionWorker) submit(ctx context.Context, txJob *TxJob, feeBumpTx *txnbuild.FeeBumpTransaction) error { + resp, err := tw.engine.HorizonClient.SubmitFeeBumpTransactionWithOptions(feeBumpTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}) + if err != nil { + err = tw.handleFailedTransaction(ctx, txJob, resp, utils.NewHorizonErrorWrapper(err)) + if err != nil { + return fmt.Errorf("handling failed transaction: %w", err) + } + } else { + err = tw.handleSuccessfulTransaction(ctx, txJob, resp) + if err != nil { + return fmt.Errorf("handling successful transaction: %w", err) + } + } + + return nil +} + +func (tw *TransactionWorker) saveResponseXDRIfPresent(ctx context.Context, txJob *TxJob, resp horizon.Transaction) error { + if tssUtils.IsEmpty(resp) { + return nil + } + + resultXDR := resp.ResultXdr + updatedTx, err := tw.txModel.UpdateStellarTransactionXDRReceived(ctx, txJob.Transaction.ID, resultXDR) + if err != nil { + return fmt.Errorf("updating XDRReceived(%s) for job %v: %w", resultXDR, txJob, err) + } + txJob.Transaction = *updatedTx + + return nil +} + +// TODO: possibly use this code as a reference when addressing [SDP-772]. +// updateTransactionsMetric calculates and observes metrics for a given Transaction +// func (s *Submitter) updateTransactionsMetric(ctx context.Context, result, error_type string, tx *store.Transaction) { +// retried := "false" +// if tx.RetryCount > 0 { +// retried = "true" +// } +// labels := map[string]string{ +// "result": result, +// "error_type": error_type, +// "retried": retried, +// } +// // observe latency taken for transaction to complete +// err := s.MonitorService.MonitorHistogram(time.Since(*tx.CreatedAt).Seconds(), monitor.TransactionQueuedToCompletedLatencyTag, labels) +// if err != nil { +// log.Ctx(ctx).Errorf("error updating transaction metric counter: %s", err.Error()) +// } + +// err = s.MonitorService.MonitorHistogram(time.Since(*tx.StartedAt).Seconds(), monitor.TransactionStartedToCompletedLatencyTag, labels) +// if err != nil { +// log.Ctx(ctx).Errorf("error updating transaction metric counter: %s", err.Error()) +// } + +// err = s.MonitorService.MonitorHistogram(float64(tx.RetryCount), monitor.TransactionRetryCountTag, labels) +// if err != nil { +// log.Ctx(ctx).Errorf("error updating transaction metric counter: %s", err.Error()) +// } + +// err = s.MonitorService.MonitorCounters(monitor.TransactionProcessedCounterTag, labels) +// if err != nil { +// log.Ctx(ctx).Errorf("error updating transaction metric counter: %s", err.Error()) +// } +// } + +// // observeHorizonErrorMetric observes error metrics from horizon +// func (s *Submitter) observeHorizonErrorMetric(ctx context.Context, statusCode int, resultCode string) { +// labels := map[string]string{ +// "status_code": strconv.Itoa(statusCode), +// "result_code": resultCode, +// } +// err := s.MonitorService.MonitorCounters(monitor.HorizonErrorCounterTag, labels) +// if err != nil { +// log.Ctx(ctx).Errorf("error updating horizon error counter metric: %s", err.Error()) +// } +// } diff --git a/internal/transactionsubmission/transaction_worker_test.go b/internal/transactionsubmission/transaction_worker_test.go new file mode 100644 index 000000000..f1724bf96 --- /dev/null +++ b/internal/transactionsubmission/transaction_worker_test.go @@ -0,0 +1,901 @@ +package transactionsubmission + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/support/render/problem" + "github.com/stellar/go/txnbuild" + "github.com/stellar/stellar-disbursement-platform-backend/internal/crashtracker" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/internal/serve/httpclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine" + engineMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/engine/mocks" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store" + storeMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/store/mocks" + "github.com/stellar/stellar-disbursement-platform-backend/internal/transactionsubmission/utils" + sdpUtlis "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// getTransactionWorkerInstance is used to create a valid instance of the class TransactionWorker, which is needed in +// many tests in this file. +func getTransactionWorkerInstance(t *testing.T, dbConnectionPool db.DBConnectionPool) TransactionWorker { + t.Helper() + + txModel := store.NewTransactionModel(dbConnectionPool) + chAccModel := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + wantSubmitterEngine, err := engine.NewSubmitterEngine(&horizonclient.Client{ + HorizonURL: "https://horizon-testnet.stellar.org", + HTTP: httpclient.DefaultClient(), + }) + require.NoError(t, err) + + distributionKP := keypair.MustRandom() + wantSigService, err := engine.NewDefaultSignatureService( + network.TestNetworkPassphrase, + dbConnectionPool, + distributionKP.Seed(), + chAccModel, + &utils.PrivateKeyEncrypterMock{}, + distributionKP.Seed(), + ) + require.NoError(t, err) + + wantMaxBaseFee := 100 + + return TransactionWorker{ + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: wantSubmitterEngine, + sigService: wantSigService, + maxBaseFee: wantMaxBaseFee, + crashTrackerClient: &crashtracker.MockCrashTrackerClient{}, + } +} + +// createTxJobFixture is used to create the resoureces needed for a txJob, and return a txJob with these resources. It +// can be customized according with the parameters passed. +func createTxJobFixture(t *testing.T, ctx context.Context, dbConnectionPool db.DBConnectionPool, shouldLock bool, currentLedger, lockedToLedger int) TxJob { + t.Helper() + var err error + + txModel := store.NewTransactionModel(dbConnectionPool) + chAccModel := store.NewChannelAccountModel(dbConnectionPool) + + // Create txJob: + tx := store.CreateTransactionFixture(t, ctx, dbConnectionPool, uuid.NewString(), "USDC", "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", "GCBIRB7Q5T53H4L6P5QSI3O6LPD5MBWGM5GHE7A5NY4XT5OT4VCOEZFX", store.TransactionStatusProcessing, 1) + chAcc := store.CreateChannelAccountFixtures(t, ctx, dbConnectionPool, 1)[0] + + if shouldLock { + tx, err = txModel.Lock(ctx, dbConnectionPool, tx.ID, int32(currentLedger), int32(lockedToLedger)) + require.NoError(t, err) + assert.True(t, tx.IsLocked(int32(currentLedger))) + + chAcc, err = chAccModel.Lock(ctx, dbConnectionPool, chAcc.PublicKey, int32(currentLedger), int32(lockedToLedger)) + require.NoError(t, err) + assert.True(t, chAcc.IsLocked(int32(currentLedger))) + } + + return TxJob{ChannelAccount: *chAcc, Transaction: *tx, LockedUntilLedgerNumber: lockedToLedger} +} + +func Test_NewTransactionWorker(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + txModel := store.NewTransactionModel(dbConnectionPool) + chAccModel := &store.ChannelAccountModel{DBConnectionPool: dbConnectionPool} + + wantSubmitterEngine, err := engine.NewSubmitterEngine(&horizonclient.Client{ + HorizonURL: "https://horizon-testnet.stellar.org", + HTTP: httpclient.DefaultClient(), + }) + require.NoError(t, err) + + distributionKP := keypair.MustRandom() + wantSigService, err := engine.NewDefaultSignatureService( + network.TestNetworkPassphrase, + dbConnectionPool, + distributionKP.Seed(), + chAccModel, + &utils.PrivateKeyEncrypterMock{}, + distributionKP.Seed(), + ) + require.NoError(t, err) + + wantMaxBaseFee := 100 + wantTxProcessingLimiter := engine.NewTransactionProcessingLimiter(20) + + wantWorker := TransactionWorker{ + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: wantSubmitterEngine, + sigService: wantSigService, + maxBaseFee: wantMaxBaseFee, + crashTrackerClient: &crashtracker.MockCrashTrackerClient{}, + txProcessingLimiter: wantTxProcessingLimiter, + } + + testCases := []struct { + name string + dbConnectionPool db.DBConnectionPool + txModel *store.TransactionModel + chAccModel *store.ChannelAccountModel + engine *engine.SubmitterEngine + sigService engine.SignatureService + maxBaseFee int + crashTrackerClient crashtracker.CrashTrackerClient + txProcessingLimiter *engine.TransactionProcessingLimiter + wantError error + }{ + { + name: "validate dbConnectionPool", + wantError: fmt.Errorf("dbConnectionPool cannot be nil"), + }, + { + name: "validate txModel", + dbConnectionPool: dbConnectionPool, + wantError: fmt.Errorf("txModel cannot be nil"), + }, + { + name: "validate chAccModel", + dbConnectionPool: dbConnectionPool, + txModel: txModel, + wantError: fmt.Errorf("chAccModel cannot be nil"), + }, + { + name: "validate engine", + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + wantError: fmt.Errorf("engine cannot be nil"), + }, + { + name: "validate sigService", + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: wantSubmitterEngine, + wantError: fmt.Errorf("sigService cannot be nil"), + }, + { + name: "validate maxBaseFee", + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: wantSubmitterEngine, + sigService: wantSigService, + wantError: fmt.Errorf("maxBaseFee must be greater than or equal to 100"), + }, + { + name: "validate crashTrackerClient", + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: wantSubmitterEngine, + sigService: wantSigService, + maxBaseFee: wantMaxBaseFee, + wantError: fmt.Errorf("crashTrackerClient cannot be nil"), + }, + { + name: "validate txProcessingLimiter", + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: wantSubmitterEngine, + sigService: wantSigService, + maxBaseFee: wantMaxBaseFee, + crashTrackerClient: &crashtracker.MockCrashTrackerClient{}, + wantError: fmt.Errorf("txProcessingLimiter cannot be nil"), + }, + { + name: "πŸŽ‰ successfully returns a new transaction worker", + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: wantSubmitterEngine, + sigService: wantSigService, + maxBaseFee: wantMaxBaseFee, + crashTrackerClient: &crashtracker.MockCrashTrackerClient{}, + txProcessingLimiter: wantTxProcessingLimiter, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotWorker, err := NewTransactionWorker( + tc.dbConnectionPool, + tc.txModel, + tc.chAccModel, + tc.engine, + tc.sigService, + tc.maxBaseFee, + tc.crashTrackerClient, + tc.txProcessingLimiter, + ) + + if tc.wantError != nil { + require.Error(t, err) + require.Equal(t, tc.wantError, err) + } else { + require.NoError(t, err) + require.NotEmpty(t, gotWorker) + require.Equal(t, wantWorker, gotWorker) + } + }) + } +} + +func Test_TransactionWorker_handleSuccessfulTransaction(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + currentLedger := 1 + lockedToLedger := 2 + + txModel := store.NewTransactionModel(dbConnectionPool) + chAccModel := store.NewChannelAccountModel(dbConnectionPool) + + t.Run("returns an error if UpdateStatusToSuccess fails", func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + transactionWorker := getTransactionWorkerInstance(t, dbConnectionPool) + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, currentLedger, lockedToLedger) + require.NotEmpty(t, txJob) + + // mock UpdateStatusToSuccess FAIL + errReturned := fmt.Errorf("updating transaction status to TransactionStatusSuccess: foo") + mockTxStore := &storeMocks.MockTransactionStore{} + mockTxStore. + On("UpdateStatusToSuccess", ctx, mock.AnythingOfType("store.Transaction")). + Return(nil, errReturned). + Once() + mockTxStore. + On("UpdateStellarTransactionXDRReceived", ctx, mock.AnythingOfType("string"), mock.AnythingOfType("string")). + Return(&txJob.Transaction, nil). + Once() + transactionWorker.txModel = mockTxStore + + // Run test: + err := transactionWorker.handleSuccessfulTransaction(ctx, &txJob, horizon.Transaction{Successful: true}) + require.Error(t, err) + wantErr := utils.NewTransactionStatusUpdateError("SUCCESS", txJob.Transaction.ID, false, errReturned) + require.Equal(t, wantErr, err) + + mockTxStore.AssertExpectations(t) + }) + + t.Run("returns an error if ChannelAccountModel.Unlock fails", func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + transactionWorker := getTransactionWorkerInstance(t, dbConnectionPool) + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, currentLedger, lockedToLedger) + require.NotEmpty(t, txJob) + + // mock UpdateStatusToSuccess βœ… + mockTxStore := &storeMocks.MockTransactionStore{} + mockTxStore. + On("UpdateStatusToSuccess", ctx, mock.AnythingOfType("store.Transaction")). + Return(&store.Transaction{}, nil). + Once() + mockTxStore. + On("UpdateStellarTransactionXDRReceived", ctx, mock.AnythingOfType("string"), mock.AnythingOfType("string")). + Return(&txJob.Transaction, nil). + Once() + transactionWorker.txModel = mockTxStore + + // mock channelAccount Unlock (FAIL) + errReturned := fmt.Errorf("something went wrong") + mockChAccStore := &storeMocks.MockChannelAccountStore{} + mockChAccStore. + On("Unlock", ctx, dbConnectionPool, mock.AnythingOfType("string")). + Return(nil, errReturned). + Once() + transactionWorker.chAccModel = mockChAccStore + + // Run test: + err := transactionWorker.handleSuccessfulTransaction(ctx, &txJob, horizon.Transaction{Successful: true}) + require.Error(t, err) + wantErr := fmt.Errorf("unlocking job: %w", fmt.Errorf("unlocking channel account: %w", errReturned)) + require.Equal(t, wantErr, err) + + mockTxStore.AssertExpectations(t) + }) + + t.Run("returns an error TransactionModel.Unlock fails", func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + transactionWorker := getTransactionWorkerInstance(t, dbConnectionPool) + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, currentLedger, lockedToLedger) + require.NotEmpty(t, txJob) + + // mock UpdateStatusToSuccess βœ… + mockTxStore := &storeMocks.MockTransactionStore{} + mockTxStore. + On("UpdateStellarTransactionXDRReceived", ctx, mock.AnythingOfType("string"), mock.AnythingOfType("string")). + Return(&txJob.Transaction, nil). + Once() + mockTxStore. + On("UpdateStatusToSuccess", ctx, mock.AnythingOfType("store.Transaction")). + Return(&store.Transaction{}, nil). + Once() + + // mock channelAccount Unlock βœ… + mockChAccStore := &storeMocks.MockChannelAccountStore{} + mockChAccStore. + On("Unlock", ctx, dbConnectionPool, mock.AnythingOfType("string")). + Return(&store.ChannelAccount{}, nil). + Once() + transactionWorker.chAccModel = mockChAccStore + + // mock TransactionModel.Unlock (FAIL) + errReturned := fmt.Errorf("something went wrong") + mockTxStore. + On("Unlock", ctx, dbConnectionPool, mock.AnythingOfType("string")). + Return(nil, errReturned). + Once() + transactionWorker.txModel = mockTxStore + + // Run test: + err := transactionWorker.handleSuccessfulTransaction(ctx, &txJob, horizon.Transaction{Successful: true}) + require.Error(t, err) + wantErr := fmt.Errorf("unlocking job: %w", fmt.Errorf("unlocking transaction: %w", errReturned)) + require.Equal(t, wantErr, err) + + mockTxStore.AssertExpectations(t) + }) + + t.Run("πŸŽ‰ successfully handles a transaction success", func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + transactionWorker := getTransactionWorkerInstance(t, dbConnectionPool) + require.NotEmpty(t, transactionWorker) + + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, currentLedger, lockedToLedger) + require.NotEmpty(t, txJob) + + // Run test: + const resultXDR = "AAAAAAAAAGQAAAAAAAAAAQAAAAAAAAAOAAAAAAAAAABw2JZZYIt4n/WXKcnDow3mbTBMPrOnldetgvGUlpTSEQAAAAA=" + err := transactionWorker.handleSuccessfulTransaction(ctx, &txJob, horizon.Transaction{Successful: true, ResultXdr: resultXDR}) + require.NoError(t, err) + + // Assert the final state of the transaction in the DB: + tx, err := txModel.Get(ctx, txJob.Transaction.ID) + require.NoError(t, err) + assert.Equal(t, store.TransactionStatusSuccess, tx.Status) + assert.Equal(t, resultXDR, tx.XDRReceived.String) + assert.False(t, tx.IsLocked(int32(currentLedger))) + + // Assert the final state of the channel account in the DB: + chAcc, err := chAccModel.Get(ctx, dbConnectionPool, txJob.ChannelAccount.PublicKey, 0) + require.NoError(t, err) + assert.False(t, chAcc.IsLocked(int32(currentLedger))) + }) + + t.Run("if a transaction with successful=false is passed, we save the xdr and leave it to be checked on reconciliation", func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + transactionWorker := getTransactionWorkerInstance(t, dbConnectionPool) + require.NotEmpty(t, transactionWorker) + + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, currentLedger, lockedToLedger) + require.NotEmpty(t, txJob) + + // Run test: + const resultXDR = "AAAAAAAAAGQAAAAAAAAAAQAAAAAAAAAOAAAAAAAAAABw2JZZYIt4n/WXKcnDow3mbTBMPrOnldetgvGUlpTSEQAAAAA=" + err := transactionWorker.handleSuccessfulTransaction(ctx, &txJob, horizon.Transaction{Successful: false, ResultXdr: resultXDR}) + require.EqualError(t, err, "transaction was not successful for some reason") + + // Assert the final state of the transaction in the DB: + tx, err := txModel.Get(ctx, txJob.Transaction.ID) + require.NoError(t, err) + assert.Equal(t, store.TransactionStatusProcessing, tx.Status) + assert.Equal(t, resultXDR, tx.XDRReceived.String) + assert.True(t, tx.IsLocked(int32(currentLedger))) + + // Assert the final state of the channel account in the DB: + chAcc, err := chAccModel.Get(ctx, dbConnectionPool, txJob.ChannelAccount.PublicKey, 0) + require.NoError(t, err) + assert.True(t, chAcc.IsLocked(int32(currentLedger))) + }) +} + +func Test_TransactionWorker_reconcileSubmittedTransaction(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + const currentLedger = 1 + const lockedToLedger = 2 + const resultXDR = "AAAAAAAAAGQAAAAAAAAAAQAAAAAAAAAOAAAAAAAAAABw2JZZYIt4n/WXKcnDow3mbTBMPrOnldetgvGUlpTSEQAAAAA=" + + transactionWorker := getTransactionWorkerInstance(t, dbConnectionPool) + require.NotEmpty(t, transactionWorker) + + testCases := []struct { + name string + horizonTxResponse horizon.Transaction + horizonTxError error + wantErrContains string + shouldBeMarkedAsSuccessful bool + shouldBePushedBackToQueue bool + }{ + { + name: "πŸŽ‰ successfully verifies the tx went through and marks it as successful", + horizonTxResponse: horizon.Transaction{Successful: true, ResultXdr: resultXDR}, + }, + { + name: "πŸŽ‰ successfully verifies the tx failed and mark it for resubmission", + horizonTxResponse: horizon.Transaction{Successful: false}, + shouldBePushedBackToQueue: true, + }, + { + name: "πŸŽ‰ check the transaction returns a 404, so we mark it for resubmission", + horizonTxError: horizonclient.Error{Problem: problem.P{Status: http.StatusNotFound}}, + shouldBePushedBackToQueue: true, + }, + { + name: "un unexpected error is returned, so we wrap and send to the caller", + horizonTxError: horizonclient.Error{Problem: problem.P{Status: http.StatusTooManyRequests}}, + shouldBePushedBackToQueue: false, + wantErrContains: "unexpected error: horizon response error: StatusCode=429", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + const txHash = "3389e9f0f1a65f19736cacf544c2e825313e8447f569233bb8db39aa607c8889" + const envelopeXDR = "AAAAAGL8HQvQkbK2HA3WVjRrKmjX00fG8sLI7m0ERwJW/AX3AAAACgAAAAAAAAABAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAArqN6LeOagjxMaUP96Bzfs9e0corNZXzBWJkFoK7kvkwAAAAAO5rKAAAAAAAAAAABVvwF9wAAAEAKZ7IPj/46PuWU6ZOtyMosctNAkXRNX9WCAI5RnfRk+AyxDLoDZP/9l3NvsxQtWj9juQOuoBlFLnWu8intgxQA" + + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, currentLedger, lockedToLedger) + tx, err := transactionWorker.txModel.UpdateStellarTransactionHashAndXDRSent(ctx, txJob.Transaction.ID, txHash, envelopeXDR) + require.NoError(t, err) + txJob.Transaction = *tx + + // mock LedgerNumberTracker + mockLedgerNumberTracker := &engineMocks.MockLedgerNumberTracker{} + mockLedgerNumberTracker.On("GetLedgerNumber").Return(currentLedger, nil).Once() + transactionWorker.engine.LedgerNumberTracker = mockLedgerNumberTracker + + // mock TransactionDetail + hMock := &horizonclient.MockClient{} + hMock.On("TransactionDetail", txHash).Return(tc.horizonTxResponse, tc.horizonTxError).Once() + transactionWorker.engine.HorizonClient = hMock + + // Run test: + err = transactionWorker.reconcileSubmittedTransaction(ctx, &txJob) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } else { + require.NoError(t, err) + } + + if tc.shouldBeMarkedAsSuccessful { + // Assert the final state of the transaction in the DB: + tx, err := transactionWorker.txModel.Get(ctx, txJob.Transaction.ID) + require.NoError(t, err) + assert.Equal(t, store.TransactionStatusSuccess, tx.Status) + assert.False(t, tx.IsLocked(int32(currentLedger))) + + // Assert the final state of the channel account in the DB: + chAcc, err := transactionWorker.chAccModel.Get(ctx, dbConnectionPool, txJob.ChannelAccount.PublicKey, 0) + require.NoError(t, err) + assert.False(t, chAcc.IsLocked(int32(currentLedger))) + } + + if tc.shouldBePushedBackToQueue { + // Assert the final state of the transaction in the DB: + tx, err := transactionWorker.txModel.Get(ctx, txJob.Transaction.ID) + require.NoError(t, err) + assert.Equal(t, store.TransactionStatusProcessing, tx.Status) + assert.False(t, tx.IsLocked(int32(currentLedger))) + + // Assert the final state of the channel account in the DB: + chAcc, err := transactionWorker.chAccModel.Get(ctx, dbConnectionPool, txJob.ChannelAccount.PublicKey, 0) + require.NoError(t, err) + assert.False(t, chAcc.IsLocked(int32(currentLedger))) + } + + mockLedgerNumberTracker.AssertExpectations(t) + hMock.AssertExpectations(t) + }) + } +} + +func Test_TransactionWorker_validateJob(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + const currentLedger int32 = 1 + const lockedToLedger int32 = 2 + + testCases := []struct { + name string + initialTransactionStatus store.TransactionStatus + wantHorizonErrorStatusCode int + shouldLockTx bool + shouldLockChAcc bool + wantErrContains string + }{ + { + name: "returns an error if the initial transaction status is SUCCESS", + initialTransactionStatus: store.TransactionStatusSuccess, + wantErrContains: "invalid transaction status: SUCCESS", + }, + { + name: "returns an error if the initial transaction status is ERROR", + initialTransactionStatus: store.TransactionStatusError, + wantErrContains: "invalid transaction status: ERROR", + }, + { + name: "returns an error if horizon returns an error", + wantHorizonErrorStatusCode: http.StatusBadGateway, + initialTransactionStatus: store.TransactionStatusProcessing, + wantErrContains: "getting current ledger number: ", + }, + { + name: "returns an error if job's tx is not locked", + wantHorizonErrorStatusCode: http.StatusOK, + initialTransactionStatus: store.TransactionStatusProcessing, + wantErrContains: "transaction should be locked", + }, + { + name: "returns an error if job's channel account is not locked", + wantHorizonErrorStatusCode: http.StatusOK, + initialTransactionStatus: store.TransactionStatusProcessing, + shouldLockTx: true, + wantErrContains: "channel account should be locked", + }, + { + name: "πŸŽ‰ successfully validate job when the resources are locked, horizon works and status is supported (PROCESSING)", + wantHorizonErrorStatusCode: http.StatusOK, + initialTransactionStatus: store.TransactionStatusProcessing, + shouldLockTx: true, + shouldLockChAcc: true, + }, + { + name: "πŸŽ‰ successfully validate job when the resources are locked, horizon works and status is supported (PENDING)", + wantHorizonErrorStatusCode: http.StatusOK, + initialTransactionStatus: store.TransactionStatusProcessing, + shouldLockTx: true, + shouldLockChAcc: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + hMock := &horizonclient.MockClient{} + if tc.wantHorizonErrorStatusCode == http.StatusOK { + hMock.On("Root").Return(horizon.Root{HorizonSequence: int32(currentLedger)}, nil).Once() + } else if tc.wantHorizonErrorStatusCode != 0 { + hMock.On("Root").Return(horizon.Root{}, horizonclient.Error{Problem: problem.P{Status: http.StatusBadGateway}}).Once() + } + + // Create a transaction worker: + submitterEngine, err := engine.NewSubmitterEngine(hMock) + require.NoError(t, err) + transactionWorker := &TransactionWorker{ + engine: submitterEngine, + txModel: store.NewTransactionModel(dbConnectionPool), + chAccModel: store.NewChannelAccountModel(dbConnectionPool), + } + + // create txJob: + txJob := createTxJobFixture(t, ctx, dbConnectionPool, false, int(currentLedger), int(lockedToLedger)) + + // Update status for txJob.Transaction + var updatedTx store.Transaction + q := `UPDATE submitter_transactions SET status = $1 WHERE id = $2 RETURNING *` + err = dbConnectionPool.GetContext(ctx, &updatedTx, q, tc.initialTransactionStatus, txJob.Transaction.ID) + require.NoError(t, err) + txJob.Transaction = updatedTx + + // Lock txJob Channel account and transaction: + if tc.shouldLockTx { + lockedTx, innerErr := transactionWorker.txModel.Lock(ctx, dbConnectionPool, txJob.Transaction.ID, currentLedger, lockedToLedger) + require.NoError(t, innerErr) + txJob.Transaction = *lockedTx + } + if tc.shouldLockChAcc { + lockedChAcc, innerErr := transactionWorker.chAccModel.Lock(ctx, dbConnectionPool, txJob.ChannelAccount.PublicKey, currentLedger, lockedToLedger) + require.NoError(t, innerErr) + txJob.ChannelAccount = *lockedChAcc + } + + // Run test: + err = transactionWorker.validateJob(&txJob) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrContains) + } else { + require.NoError(t, err) + } + + hMock.AssertExpectations(t) + }) + } +} + +func Test_TransactionWorker_buildAndSignTransaction(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + const currentLedger = 1 + const lockedToLedger = 2 + const accountSequence = 123 + + distributionKP := keypair.MustRandom() + sigService, err := engine.NewDefaultSignatureService( + network.TestNetworkPassphrase, + dbConnectionPool, + distributionKP.Seed(), + store.NewChannelAccountModel(dbConnectionPool), + &utils.PrivateKeyEncrypterMock{}, + distributionKP.Seed(), + ) + require.NoError(t, err) + + testCases := []struct { + name string + assetCode string + assetIssuer string + getAccountResponseObj horizon.Account + getAccountResponseError *horizonclient.Error + wantErrorContains string + }{ + { + name: "returns an error if the asset code is empty", + wantErrorContains: "asset code cannot be empty", + }, + { + name: "returns an error if the asset code is not XLM and the issuer is not valid", + assetCode: "USDC", + assetIssuer: "FOOBAR", + wantErrorContains: "invalid asset issuer: FOOBAR", + }, + { + name: "return an error if the AccountDetail call fails", + assetCode: "USDC", + assetIssuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + getAccountResponseObj: horizon.Account{}, + getAccountResponseError: &horizonclient.Error{Problem: problem.P{Status: http.StatusTooManyRequests}}, + wantErrorContains: "horizon response error: ", + }, + { + name: "πŸŽ‰ successfully build and sign a transaction", + assetCode: "USDC", + assetIssuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + getAccountResponseObj: horizon.Account{Sequence: accountSequence}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, currentLedger, lockedToLedger) + txJob.Transaction.AssetCode = tc.assetCode + txJob.Transaction.AssetIssuer = tc.assetIssuer + + // mock horizon + mockHorizon := &horizonclient.MockClient{} + if !sdpUtlis.IsEmpty(tc.getAccountResponseObj) || !sdpUtlis.IsEmpty(tc.getAccountResponseError) { + var hErr error + if tc.getAccountResponseError != nil { + hErr = tc.getAccountResponseError + } + mockHorizon.On("AccountDetail", horizonclient.AccountRequest{AccountID: txJob.ChannelAccount.PublicKey}).Return(tc.getAccountResponseObj, hErr).Once() + } + mockStore := &storeMocks.MockChannelAccountStore{} + mockStore.On("Get", ctx, mock.Anything, txJob.ChannelAccount.PublicKey, 0).Return(txJob.ChannelAccount, nil) + + // Create a transaction worker: + submitterEngine := &engine.SubmitterEngine{HorizonClient: mockHorizon} + transactionWorker := &TransactionWorker{ + engine: submitterEngine, + txModel: store.NewTransactionModel(dbConnectionPool), + chAccModel: store.NewChannelAccountModel(dbConnectionPool), + sigService: sigService, + maxBaseFee: 100, + } + + // Run test: + gotFeeBumpTx, err := transactionWorker.buildAndSignTransaction(context.Background(), &txJob) + if tc.wantErrorContains != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tc.wantErrorContains) + assert.Nil(t, gotFeeBumpTx) + } else { + require.NoError(t, err) + require.NotNil(t, gotFeeBumpTx) + + // Check that the transaction was built correctly: + var wantAsset txnbuild.Asset = txnbuild.NativeAsset{} + if strings.ToUpper(txJob.Transaction.AssetCode) != "XLM" { + wantAsset = txnbuild.CreditAsset{ + Code: txJob.Transaction.AssetCode, + Issuer: txJob.Transaction.AssetIssuer, + } + } + wantInnerTx, err := txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: txJob.ChannelAccount.PublicKey, + Sequence: accountSequence, + }, + Operations: []txnbuild.Operation{ + &txnbuild.Payment{ + SourceAccount: distributionKP.Address(), + Amount: strconv.FormatFloat(txJob.Transaction.Amount, 'f', 6, 32), // TODO find a better way to do this + Destination: txJob.Transaction.Destination, + Asset: wantAsset, + }, + }, + BaseFee: int64(transactionWorker.maxBaseFee), + Preconditions: txnbuild.Preconditions{ + TimeBounds: txnbuild.NewTimeout(300), + LedgerBounds: &txnbuild.LedgerBounds{MaxLedger: uint32(txJob.LockedUntilLedgerNumber)}, + }, + IncrementSequenceNum: true, + }, + ) + require.NoError(t, err) + wantInnerTx, err = sigService.SignStellarTransaction(ctx, wantInnerTx, distributionKP.Address(), txJob.ChannelAccount.PublicKey) + require.NoError(t, err) + + wantFeeBumpTx, err := txnbuild.NewFeeBumpTransaction( + txnbuild.FeeBumpTransactionParams{ + Inner: wantInnerTx, + FeeAccount: distributionKP.Address(), + BaseFee: int64(transactionWorker.maxBaseFee), + }, + ) + require.NoError(t, err) + wantFeeBumpTx, err = sigService.SignFeeBumpStellarTransaction(ctx, wantFeeBumpTx, distributionKP.Address()) + require.NoError(t, err) + assert.Equal(t, wantFeeBumpTx, gotFeeBumpTx) + } + + mockHorizon.AssertExpectations(t) + }) + } +} + +func Test_TransactionWorker_submit(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + txModel := store.NewTransactionModel(dbConnectionPool) + chAccModel := store.NewChannelAccountModel(dbConnectionPool) + const resultXDR = "AAAAAAAAAGQAAAAAAAAAAQAAAAAAAAAOAAAAAAAAAABw2JZZYIt4n/WXKcnDow3mbTBMPrOnldetgvGUlpTSEQAAAAA=" + + testCases := []struct { + name string + horizonResponse horizon.Transaction + horizonError error + wantFinalTransactionStatus store.TransactionStatus + wantFinalResultXDR string + }{ + { + name: "unrecoverable horizon error is handled and tx status is marked as ERROR", + horizonResponse: horizon.Transaction{}, + horizonError: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_failed", + "operations": []string{"op_no_trust"}, // <--- this should make the transaction be marked as ERROR + }, + }, + }, + }, + wantFinalTransactionStatus: store.TransactionStatusError, + }, + { + name: "successful horizon error is handled and tx status is marked as SUCCESS", + horizonResponse: horizon.Transaction{Successful: true, ResultXdr: resultXDR}, + horizonError: nil, + wantFinalTransactionStatus: store.TransactionStatusSuccess, + wantFinalResultXDR: resultXDR, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer store.DeleteAllFromChannelAccounts(t, ctx, dbConnectionPool) + defer store.DeleteAllTransactionFixtures(t, ctx, dbConnectionPool) + + txJob := createTxJobFixture(t, ctx, dbConnectionPool, true, 1, 2) + feeBumpTx := &txnbuild.FeeBumpTransaction{} + + mockHorizonClient := &horizonclient.MockClient{} + txProcessingLimiter := engine.NewTransactionProcessingLimiter(15) + mockHorizonClient. + On("SubmitFeeBumpTransactionWithOptions", feeBumpTx, horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). + Return(tc.horizonResponse, tc.horizonError). + Once() + transactionWorker := TransactionWorker{ + dbConnectionPool: dbConnectionPool, + txModel: txModel, + chAccModel: chAccModel, + engine: &engine.SubmitterEngine{ + HorizonClient: mockHorizonClient, + }, + txProcessingLimiter: txProcessingLimiter, + } + + // make sure the tx's initial status is PROCESSING: + refreshedTx, err := txModel.Get(ctx, txJob.Transaction.ID) + require.NoError(t, err) + require.Equal(t, store.TransactionStatusProcessing, refreshedTx.Status) + assert.Equal(t, *refreshedTx, txJob.Transaction) + + err = transactionWorker.submit(ctx, &txJob, feeBumpTx) + require.NoError(t, err) + + // make sure the tx's status is the expected one: + refreshedTx, err = txModel.Get(ctx, txJob.Transaction.ID) + require.NoError(t, err) + require.Equal(t, tc.wantFinalTransactionStatus, refreshedTx.Status) + assert.Equal(t, tc.wantFinalResultXDR, refreshedTx.XDRReceived.String) + + // check if the channel account was unlocked: + refreshedChAcc, err := chAccModel.Get(ctx, dbConnectionPool, txJob.ChannelAccount.PublicKey, 0) + require.NoError(t, err) + assert.False(t, refreshedChAcc.IsLocked(int32(txJob.LockedUntilLedgerNumber))) + + mockHorizonClient.AssertExpectations(t) + }) + } +} diff --git a/internal/transactionsubmission/utils/errors.go b/internal/transactionsubmission/utils/errors.go new file mode 100644 index 000000000..b8eb9db05 --- /dev/null +++ b/internal/transactionsubmission/utils/errors.go @@ -0,0 +1,305 @@ +package utils + +import ( + "fmt" + "net/http" + "strings" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/support/log" + "github.com/stellar/go/support/render/problem" + sdpUtils "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" + "golang.org/x/exp/slices" +) + +// TransactionStatusUpdateError is an error that occurs when failing to update a transaction's status. +type TransactionStatusUpdateError struct { + Status string + TxID string + ForRetry bool + // Err is the underlying error that caused the transaction status update to fail. + Err error +} + +func (e *TransactionStatusUpdateError) Error() string { + forRetry := "" + if e.ForRetry { + forRetry = " (for retry)" + } + return fmt.Sprintf("updating transaction(ID=%q) status to %s%s: %v", e.TxID, e.Status, forRetry, e.Err) +} + +func (e *TransactionStatusUpdateError) Unwrap() error { + return e.Err +} + +func NewTransactionStatusUpdateError(status, txID string, forRetry bool, err error) *TransactionStatusUpdateError { + return &TransactionStatusUpdateError{ + Status: status, + TxID: txID, + ForRetry: forRetry, + Err: err, + } +} + +var _ error = &TransactionStatusUpdateError{} + +// HorizonErrorWrapper is an error that occurs when a horizon response is not successful. +type HorizonErrorWrapper struct { + StatusCode int + Problem problem.P + Err error + ResultCodes *horizon.TransactionResultCodes +} + +func NewHorizonErrorWrapper(err error) *HorizonErrorWrapper { + if err == nil { + return nil + } + + hError := horizonclient.GetError(err) + if hError == nil { + return &HorizonErrorWrapper{ + Err: err, + } + } + + resultCodes, resCodeErr := hError.ResultCodes() + if resCodeErr != nil { + log.Errorf("parsing result_codes: %v", resCodeErr) + } + + return &HorizonErrorWrapper{ + Err: err, + Problem: hError.Problem, + StatusCode: hError.Problem.Status, + ResultCodes: resultCodes, + } +} + +func (e *HorizonErrorWrapper) Unwrap() error { + return e.Err +} + +func (e *HorizonErrorWrapper) Error() string { + if !e.IsHorizonError() { + return fmt.Sprintf("horizon response error: %v", e.Err) + } + + msgBuilder := &strings.Builder{} + msgBuilder.WriteString(fmt.Sprintf("horizon response error: StatusCode=%d", e.StatusCode)) + if e.Problem.Type != "" { + msgBuilder.WriteString(fmt.Sprintf(", Type=%s", e.Problem.Type)) + } + if e.Problem.Title != "" { + msgBuilder.WriteString(fmt.Sprintf(", Title=%s", e.Problem.Title)) + } + if e.Problem.Detail != "" { + msgBuilder.WriteString(fmt.Sprintf(", Detail=%s", e.Problem.Detail)) + } + // TODO: place extras right after status codes, for better readability. Details are pretty verbose and not that useful. + if e.HasResultCodes() { + e.handleExtrasResultCodes(msgBuilder) + } + return msgBuilder.String() +} + +func (e *HorizonErrorWrapper) IsHorizonError() bool { + return !sdpUtils.IsEmpty(e.Problem) +} + +func (e *HorizonErrorWrapper) IsNotFound() bool { + return e.IsHorizonError() && e.StatusCode == http.StatusNotFound +} + +func (e *HorizonErrorWrapper) IsRateLimit() bool { + return e.IsHorizonError() && e.StatusCode == http.StatusTooManyRequests +} + +func (e *HorizonErrorWrapper) IsGatewayTimeout() bool { + return e.IsHorizonError() && e.StatusCode == http.StatusGatewayTimeout +} + +func (e *HorizonErrorWrapper) HasResultCodes() bool { + return e.IsHorizonError() && e.ResultCodes != nil +} + +// IsNotEnoughLumens verifies if the Horizon Error is related to the +// transaction attempting to bring the source account lumens balance below the minimum reserve. +func (e *HorizonErrorWrapper) IsNotEnoughLumens() bool { + if !e.HasResultCodes() { + return false + } + + code := "tx_insufficient_balance" + opCode := "op_underfunded" + return (e.ResultCodes.TransactionCode == code || + e.ResultCodes.InnerTransactionCode == code || + slices.Contains(e.ResultCodes.OperationCodes, opCode)) +} + +// IsNoSourceAccount verifies if the Horizon Error is related to the +// source account not being found. +func (e *HorizonErrorWrapper) IsNoSourceAccount() bool { + if !e.HasResultCodes() { + return false + } + + txCode := "tx_no_source_account" + opCode := "op_no_source_account" + return (e.ResultCodes.TransactionCode == txCode || + e.ResultCodes.InnerTransactionCode == txCode || + slices.Contains(e.ResultCodes.OperationCodes, opCode)) +} + +// IsNoIssuer verifies if the Horizon Error is related to the +// issuer of the asset not existing. +func (e *HorizonErrorWrapper) IsNoIssuer() bool { + if !e.HasResultCodes() { + return false + } + + opCode := "op_no_issuer" + return slices.Contains(e.ResultCodes.OperationCodes, opCode) +} + +// IsSourceNotAuthorized verifies if the Horizon Error is related to the +// source account not having authorization from the asset issuer to send the asset. +func (e *HorizonErrorWrapper) IsSourceAccountNotAuthorized() bool { + if !e.HasResultCodes() { + return false + } + + opCode := "op_src_not_authorized" + return slices.Contains(e.ResultCodes.OperationCodes, opCode) +} + +// IsSourceNoTrustline verifies if the Horizon Error is related to the +// source account not having a trustline for the asset being sent. +func (e *HorizonErrorWrapper) IsSourceNoTrustline() bool { + if !e.HasResultCodes() { + return false + } + + opCode := "op_src_no_trust" + return slices.Contains(e.ResultCodes.OperationCodes, opCode) +} + +// IsDestinationAccountNotAuthorized verifies if the Horizon Error is related to the +// destination account is not being authorized by the asset issuer to receive the asset. +func (e *HorizonErrorWrapper) IsDestinationAccountNotAuthorized() bool { + if !e.HasResultCodes() { + return false + } + + opCode := "op_not_authorized" + return slices.Contains(e.ResultCodes.OperationCodes, opCode) +} + +// IsNoTrustline verifies if the Horizon Error is related to the +// destination account not having a trustline for the asset being sent. +func (e *HorizonErrorWrapper) IsDestinationNoTrustline() bool { + if !e.HasResultCodes() { + return false + } + + opCode := "op_no_trust" + return slices.Contains(e.ResultCodes.OperationCodes, opCode) +} + +// IsLineFull verifies if the Horizon Error is related to the +// destination account not having sufficient limits to receive the payment amount +// and still satisfy its buying liabilities. +func (e *HorizonErrorWrapper) IsLineFull() bool { + if !e.HasResultCodes() { + return false + } + + opCode := "op_line_full" + return slices.Contains(e.ResultCodes.OperationCodes, opCode) +} + +// IsNoDestinationAccount verifies if the Horizon Error is related to the +// destination account not existing. +func (e *HorizonErrorWrapper) IsNoDestinationAccount() bool { + if !e.HasResultCodes() { + return false + } + + opCode := "op_no_destination" + return slices.Contains(e.ResultCodes.OperationCodes, opCode) +} + +// IsBadAuthentication verifies if the Horizon Error is related to +// invalid transaction or operation signatures. +func (e *HorizonErrorWrapper) IsBadAuthentication() bool { + if !e.HasResultCodes() { + return false + } + + txCodes := []string{"tx_bad_auth", "tx_bad_auth_extra"} + opCode := "op_bad_auth" + return (slices.Contains(txCodes, e.ResultCodes.TransactionCode) || + slices.Contains(txCodes, e.ResultCodes.InnerTransactionCode) || + slices.Contains(e.ResultCodes.OperationCodes, opCode)) +} + +// IsTxInsufficientFee verifies if the Horizon Error is related to the +// fee submitted being too small to be accepted by to the ledger by +// the network. +func (e *HorizonErrorWrapper) IsTxInsufficientFee() bool { + if !e.HasResultCodes() { + return false + } + + txCode := "tx_insufficient_fee" + return e.ResultCodes.TransactionCode == txCode +} + +// IsSourceAccountNotReady verifies if the Horizon Error is related to the +// source account of the transaction. It gathers all errors that would happen +// in a transaction because of a misconfiguration of the source account. +func (e *HorizonErrorWrapper) IsSourceAccountNotReady() bool { + return (e.IsNotEnoughLumens() || + e.IsNoSourceAccount() || + e.IsSourceAccountNotAuthorized() || + e.IsSourceNoTrustline()) +} + +// IsDestinationAccountNotReady verifies if the Horizon Error is related to the +// destination account of the transaction. It gathers all errors that would happen +// in a transaction because of a misconfiguration of the destination account. +func (e *HorizonErrorWrapper) IsDestinationAccountNotReady() bool { + return (e.IsDestinationAccountNotAuthorized() || + e.IsDestinationNoTrustline() || + e.IsNoDestinationAccount() || + e.IsLineFull()) +} + +func (e *HorizonErrorWrapper) handleExtrasResultCodes(msgBuilder *strings.Builder) { + if !e.HasResultCodes() { + return + } + + extras := []string{} + if e.ResultCodes.TransactionCode != "" { + extras = append(extras, fmt.Sprintf("transaction: %s", e.ResultCodes.TransactionCode)) + } + + if e.ResultCodes.InnerTransactionCode != "" { + extras = append(extras, fmt.Sprintf("inner transaction: %s", e.ResultCodes.InnerTransactionCode)) + } + + if len(e.ResultCodes.OperationCodes) > 0 { + msg := fmt.Sprintf("operation codes: [ %s ]", strings.Join(e.ResultCodes.OperationCodes, ", ")) + extras = append(extras, msg) + } + + if len(extras) > 0 { + msgBuilder.WriteString(", Extras=") + msgBuilder.WriteString(strings.Join(extras, " - ")) + } +} + +var _ error = &HorizonErrorWrapper{} diff --git a/internal/transactionsubmission/utils/errors_test.go b/internal/transactionsubmission/utils/errors_test.go new file mode 100644 index 000000000..5cbb85db0 --- /dev/null +++ b/internal/transactionsubmission/utils/errors_test.go @@ -0,0 +1,1517 @@ +package utils + +import ( + "errors" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/protocols/horizon" + "github.com/stellar/go/support/render/problem" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_NewTransactionStatusUpdateError(t *testing.T) { + status := "ERROR" + txID := "some-tx-id" + forRetry := false + err := fmt.Errorf("some error") + txStatusUpdateErr := NewTransactionStatusUpdateError(status, txID, forRetry, err) + + wantTxStatusUpdateErr := &TransactionStatusUpdateError{ + Status: status, + TxID: txID, + ForRetry: forRetry, + Err: err, + } + require.Equal(t, wantTxStatusUpdateErr, txStatusUpdateErr) +} + +func Test_TransactionStatusUpdateError_Error(t *testing.T) { + testCases := []struct { + name string + status string + txID string + forRetry bool + err error + wantStringResult string + }{ + { + name: "PENDING for retry", + status: "PENDING", + txID: "foo", + forRetry: true, + err: fmt.Errorf("some causing error"), + wantStringResult: "updating transaction(ID=\"foo\") status to PENDING (for retry): some causing error", + }, + { + name: "ERROR without retry", + status: "ERROR", + txID: "bar", + forRetry: false, + err: fmt.Errorf("another causing error"), + wantStringResult: "updating transaction(ID=\"bar\") status to ERROR: another causing error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + txStatusUpdateErr := NewTransactionStatusUpdateError(tc.status, tc.txID, tc.forRetry, tc.err) + require.Equal(t, tc.wantStringResult, txStatusUpdateErr.Error()) + }) + } +} + +func Test_TransactionStatusUpdateError_Unwrap_and_Is(t *testing.T) { + err := fmt.Errorf("some causing error") + txStatusUpdateErr := NewTransactionStatusUpdateError("ERROR", "some-tx-id", false, err) + require.Equal(t, err, txStatusUpdateErr.Unwrap()) + require.True(t, errors.Is(txStatusUpdateErr, err)) +} + +func Test_TransactionStatusUpdateError_As(t *testing.T) { + err := fmt.Errorf("some causing error") + var someError error = NewTransactionStatusUpdateError("ERROR", "some-tx-id", false, err) + + var txStatusUpdateErr *TransactionStatusUpdateError + require.True(t, errors.As(someError, &txStatusUpdateErr)) + + err = fmt.Errorf("sandwich the error: %w", txStatusUpdateErr) + require.True(t, errors.As(err, &txStatusUpdateErr)) +} + +func Test_NewHorizonErrorWrapper(t *testing.T) { + hError := horizonclient.Error{ + Problem: problem.P{ + Title: "Transaction Failed", + Type: "transaction_failed", + Status: http.StatusBadRequest, + Detail: "", + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_failed", + "operations": []string{"op_underfunded"}, + }, + }, + }, + } + + testCases := []struct { + name string + originalErr error + wantHorizonResponseErr *HorizonErrorWrapper + }{ + { + name: "nil error", + originalErr: nil, + wantHorizonResponseErr: nil, + }, + { + name: "non-horizon error", + originalErr: fmt.Errorf("some error"), + wantHorizonResponseErr: &HorizonErrorWrapper{Err: fmt.Errorf("some error")}, + }, + { + name: "horizon error (value)", + originalErr: hError, + wantHorizonResponseErr: &HorizonErrorWrapper{ + StatusCode: http.StatusBadRequest, + Problem: hError.Problem, + Err: hError, + ResultCodes: &horizon.TransactionResultCodes{ + TransactionCode: "tx_failed", + InnerTransactionCode: "", + OperationCodes: []string{"op_underfunded"}, + }, + }, + }, + { + name: "horizon error (pointer)", + originalErr: &hError, + wantHorizonResponseErr: &HorizonErrorWrapper{ + StatusCode: http.StatusBadRequest, + Problem: hError.Problem, + Err: &hError, + ResultCodes: &horizon.TransactionResultCodes{ + TransactionCode: "tx_failed", + InnerTransactionCode: "", + OperationCodes: []string{"op_underfunded"}, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + horizonResponseErr := NewHorizonErrorWrapper(tc.originalErr) + require.Equal(t, tc.wantHorizonResponseErr, horizonResponseErr) + }) + } +} + +func Test_HorizonErrorWrapper_Error(t *testing.T) { + testCases := []struct { + name string + originalErr error + wantStringResult string + }{ + { + name: "non-horizon error", + originalErr: fmt.Errorf("something went wrong with TCP IP stuff"), + wantStringResult: "horizon response error: something went wrong with TCP IP stuff", + }, + { + name: "horizon error", + originalErr: horizonclient.Error{ + Problem: problem.P{ + Title: "Transaction Failed", + Type: "transaction_failed", + Status: http.StatusBadRequest, + Detail: "some-detail", + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_failed", + "operations": []string{"op_underfunded"}, + }, + }, + }, + }, + wantStringResult: `horizon response error: StatusCode=400, Type=transaction_failed, Title=Transaction Failed, Detail=some-detail, Extras=transaction: tx_failed - operation codes: [ op_underfunded ]`, + }, + { + name: "horizon error with less fields", + originalErr: horizonclient.Error{ + Problem: problem.P{ + Type: "transaction_failed", + Status: http.StatusBadRequest, + }, + }, + wantStringResult: "horizon response error: StatusCode=400, Type=transaction_failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + txStatusUpdateErr := NewHorizonErrorWrapper(tc.originalErr) + require.Equal(t, tc.wantStringResult, txStatusUpdateErr.Error()) + }) + } +} + +func Test_HorizonErrorWrapper_Unwrap_and_Is(t *testing.T) { + err := fmt.Errorf("some causing error") + horizonErrorWrapper := NewHorizonErrorWrapper(err) + require.Equal(t, err, horizonErrorWrapper.Unwrap()) + require.True(t, errors.Is(horizonErrorWrapper, err)) +} + +func Test_HorizonErrorWrapper_As(t *testing.T) { + err := fmt.Errorf("some causing error") + var someError error = NewHorizonErrorWrapper(err) + + var horizonErrorWrapper *HorizonErrorWrapper + require.True(t, errors.As(someError, &horizonErrorWrapper)) + require.NotNil(t, horizonErrorWrapper) + + err = fmt.Errorf("sandwich the error: %w", horizonErrorWrapper) + require.True(t, errors.As(err, &horizonErrorWrapper)) +} + +func Test_HorizonErrorWrapper_IsNotFound(t *testing.T) { + testCases := []struct { + name string + originalErr error + wantResult bool + }{ + { + name: "non-horizon error, returns FALSE", + originalErr: fmt.Errorf("something went wrong with TCP IP stuff"), + wantResult: false, + }, + { + name: "400 horizon error, returns FALSE", + originalErr: horizonclient.Error{Problem: problem.P{Status: http.StatusBadRequest}}, + wantResult: false, + }, + { + name: "404 horizon error, returns TRUE", + originalErr: horizonclient.Error{Problem: problem.P{Status: http.StatusNotFound}}, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + txStatusUpdateErr := NewHorizonErrorWrapper(tc.originalErr) + require.Equal(t, tc.wantResult, txStatusUpdateErr.IsNotFound()) + }) + } +} + +func Test_HorizonErrorWrapper_IsRateLimit(t *testing.T) { + testCases := []struct { + name string + originalErr error + wantResult bool + }{ + { + name: "non-horizon error, returns FALSE", + originalErr: fmt.Errorf("something went wrong with TCP IP stuff"), + wantResult: false, + }, + { + name: "400 horizon error, returns FALSE", + originalErr: horizonclient.Error{Problem: problem.P{Status: http.StatusBadRequest}}, + wantResult: false, + }, + { + name: "429 horizon error, returns TRUE", + originalErr: horizonclient.Error{Problem: problem.P{Status: http.StatusTooManyRequests}}, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + txStatusUpdateErr := NewHorizonErrorWrapper(tc.originalErr) + require.Equal(t, tc.wantResult, txStatusUpdateErr.IsRateLimit()) + }) + } +} + +func Test_HorizonErrorWrapper_IsGatewayTimeout(t *testing.T) { + testCases := []struct { + name string + originalErr error + wantResult bool + }{ + { + name: "non-horizon error, returns FALSE", + originalErr: fmt.Errorf("something went wrong with TCP IP stuff"), + wantResult: false, + }, + { + name: "400 horizon error, returns FALSE", + originalErr: horizonclient.Error{Problem: problem.P{Status: http.StatusBadRequest}}, + wantResult: false, + }, + { + name: "504 horizon error, returns TRUE", + originalErr: horizonclient.Error{Problem: problem.P{Status: http.StatusGatewayTimeout}}, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + txStatusUpdateErr := NewHorizonErrorWrapper(tc.originalErr) + require.Equal(t, tc.wantResult, txStatusUpdateErr.IsGatewayTimeout()) + }) + } +} + +func Test_HorizonErrorWrapper_handleExtrasResultCodes(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult string + }{ + { + name: "doesn't write any content when there's no result codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "doesn't write any content when result codes is empty", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "writes the content of transaction key", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + }, + }, + }, + }, + wantResult: ", Extras=transaction: tx_fee_bump_inner_failed", + }, + { + name: "writes the content of inner_transaction key", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "inner_transaction": "tx_too_early", + }, + }, + }, + }, + wantResult: ", Extras=inner transaction: tx_too_early", + }, + { + name: "writes the content of operations key", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_failed_1", "op_failed_2", "op_failed_3"}, + }, + }, + }, + }, + wantResult: ", Extras=operation codes: [ op_failed_1, op_failed_2, op_failed_3 ]", + }, + { + name: "writes the content of transaction and inner_transaction keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_too_early", + }, + }, + }, + }, + wantResult: ", Extras=transaction: tx_fee_bump_inner_failed - inner transaction: tx_too_early", + }, + { + name: "writes the content of all keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_too_early", + "operations": []string{"op_failed_1", "op_failed_2", "op_failed_3"}, + }, + }, + }, + }, + wantResult: ", Extras=transaction: tx_fee_bump_inner_failed - inner transaction: tx_too_early - operation codes: [ op_failed_1, op_failed_2, op_failed_3 ]", + }, + } + + msgBuilder := new(strings.Builder) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + wrapper.handleExtrasResultCodes(msgBuilder) + + if tc.wantResult == "" { + assert.Empty(t, msgBuilder.String()) + } else { + assert.Contains(t, msgBuilder.String(), tc.wantResult) + } + + msgBuilder.Reset() + }) + } +} + +func Test_HorizonErrorWrapper_IsNotEnoughLumens(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to tx_insufficient_balance or op_underfunded", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed_1"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the transaction key is tx_insufficient_balance", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_insufficient_balance", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the inner_transaction key is tx_insufficient_balance", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_insufficient_balance", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the operations key contains op_underfunded", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_underfunded"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsNotEnoughLumens()) + }) + } +} + +func Test_HorizonErrorWrapper_IsNoSourceAccount(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to tx_no_source_account or op_no_source_account", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the transaction key is tx_no_source_account", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_no_source_account", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the inner_transaction key is tx_no_source_account", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_no_source_account", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the operations key contains op_no_source_account", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_inner_transaction", + "operations": []string{"op_no_source_account"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsNoSourceAccount()) + }) + } +} + +func Test_HorizonErrorWrapper_IsNoIssuer(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to op_no_issuer", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the operations key contains op_no_issuer", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_no_issuer"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsNoIssuer()) + }) + } +} + +func Test_HorizonErrorWrapper_IsSourceAccountNotAuthorized(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to op_src_not_authorized", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the operations key contains op_src_not_authorized", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_src_not_authorized"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsSourceAccountNotAuthorized()) + }) + } +} + +func Test_HorizonErrorWrapper_IsSourceNoTrustline(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to op_src_no_trust", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the operations key contains op_src_no_trust", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_src_no_trust"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsSourceNoTrustline()) + }) + } +} + +func Test_HorizonErrorWrapper_IsDestinationAccountNotAuthorized(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to op_not_authorized", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the operations key contains op_not_authorized", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_not_authorized"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsDestinationAccountNotAuthorized()) + }) + } +} + +func Test_HorizonErrorWrapper_IsDestinationNoTrustline(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to op_no_trust", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the operations key contains op_no_trust", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_no_trust"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsDestinationNoTrustline()) + }) + } +} + +func Test_HorizonErrorWrapper_IsLineFull(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to op_line_full", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the operations key contains op_line_full", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_line_full"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsLineFull()) + }) + } +} + +func Test_HorizonErrorWrapper_IsNoDestinationAccount(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to op_no_destination", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the operations key contains op_no_destination", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "operations": []string{"op_no_destination"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsNoDestinationAccount()) + }) + } +} + +func Test_HorizonErrorWrapper_IsBadAuthentication(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to tx_bad_auth, tx_bad_auth_extra, or op_bad_auth", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when the transaction key is tx_bad_auth", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_bad_auth", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the inner_transaction key is tx_bad_auth", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_bad_auth", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the transaction key is tx_bad_auth_extra", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_bad_auth_extra", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the inner_transaction key is tx_bad_auth_extra", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_bad_auth_extra", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the operations key contains op_bad_auth", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_inner_transaction", + "operations": []string{"op_bad_auth"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsBadAuthentication()) + }) + } +} + +func Test_HorizonErrorWrapper_IsTxInsufficientFee(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to tx_insufficient_fee", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + }, + { + name: "returns true when the transaction key is tx_insufficient_fee", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_insufficient_fee", + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsTxInsufficientFee()) + }) + } +} + +func Test_HorizonErrorWrapper_IsSourceAccountNotReady(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to any Source Account misconfiguration", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when source account has not enough lumens", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_insufficient_balance", + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the source account does not exist", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "tx_no_source_account", + "operations": []string{"op_no_source_account"}, + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the source account is not authorized to send the asset", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "operations": []string{"op_src_not_authorized"}, + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the source account does not have trustline for the asset", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "operations": []string{"op_src_no_trust"}, + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the source account is underfunded", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "operations": []string{"op_underfunded"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsSourceAccountNotReady()) + }) + } +} + +func Test_HorizonErrorWrapper_IsDestinationAccountNotReady(t *testing.T) { + testCases := []struct { + name string + hErr error + wantResult bool + }{ + { + name: "returns false when there's no result_codes", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{}, + }, + }, + }, + { + name: "returns false when result_codes has no keys", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{}, + }, + }, + }, + }, + { + name: "returns false when the result_codes is not related to any Destination Account misconfiguration", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "inner_transaction": "inner_tx_failed", + "operations": []string{"op_failed"}, + }, + }, + }, + }, + wantResult: false, + }, + { + name: "returns true when destination account is not authorized to receive the asset", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "operations": []string{"op_not_authorized"}, + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the destination account has no trustline for the asset", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "operations": []string{"op_no_trust"}, + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the destination account does not exist", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "operations": []string{"op_no_destination"}, + }, + }, + }, + }, + wantResult: true, + }, + { + name: "returns true when the destination account has no sufficient limit", + hErr: horizonclient.Error{ + Problem: problem.P{ + Status: http.StatusBadRequest, + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_fee_bump_inner_failed", + "operations": []string{"op_line_full"}, + }, + }, + }, + }, + wantResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrapper := NewHorizonErrorWrapper(tc.hErr) + assert.Equal(t, tc.wantResult, wrapper.IsDestinationAccountNotReady()) + }) + } +} diff --git a/internal/transactionsubmission/utils/mocks.go b/internal/transactionsubmission/utils/mocks.go new file mode 100644 index 000000000..a5af0b192 --- /dev/null +++ b/internal/transactionsubmission/utils/mocks.go @@ -0,0 +1,20 @@ +package utils + +import "github.com/stretchr/testify/mock" + +type PrivateKeyEncrypterMock struct { + mock.Mock +} + +func (pke *PrivateKeyEncrypterMock) Encrypt(message, passphrase string) (string, error) { + args := pke.Called(message, passphrase) + return args.String(0), args.Error(1) +} + +func (pke *PrivateKeyEncrypterMock) Decrypt(message, passphrase string) (string, error) { + args := pke.Called(message, passphrase) + return args.String(0), args.Error(1) +} + +// Making sure that PrivateKeyEncrypterMock implements PrivateKeyEncrypter +var _ PrivateKeyEncrypter = (*PrivateKeyEncrypterMock)(nil) diff --git a/internal/transactionsubmission/utils/test_helpers.go b/internal/transactionsubmission/utils/test_helpers.go new file mode 100644 index 000000000..b67d4d618 --- /dev/null +++ b/internal/transactionsubmission/utils/test_helpers.go @@ -0,0 +1,38 @@ +package utils + +import ( + "sync" + "testing" + "time" +) + +// WaitUntilWaitGroupIsDoneOrTimeout is a helper function that waits for a wait group to finish or times out after a +// given duration. This is used for test purposes. +func WaitUntilWaitGroupIsDoneOrTimeout(t *testing.T, wg *sync.WaitGroup, timeout time.Duration, shouldTimeout bool, assertFn func()) { + t.Helper() + + ch := make(chan struct{}) + go func() { + wg.Wait() + close(ch) + }() + + select { + case <-ch: + if shouldTimeout { + t.Fatal("wait group finished, but we expected it to timeout") + } else { + t.Log("wait group finished as expected") + } + case <-time.After(timeout): + if shouldTimeout { + t.Log("wait group correctly timed out") + } else { + t.Fatal("wait group did not finish within the expected time") + } + } + + if assertFn != nil { + assertFn() + } +} diff --git a/internal/transactionsubmission/utils/utils.go b/internal/transactionsubmission/utils/utils.go new file mode 100644 index 000000000..f50b081a0 --- /dev/null +++ b/internal/transactionsubmission/utils/utils.go @@ -0,0 +1,46 @@ +package utils + +import ( + "context" + "fmt" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + sdpUtils "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +// GetHorizonErrorString returns a string representation of a horizonclient.Error. +func GetHorizonErrorString(hError horizonclient.Error) string { + hProblem := hError.Problem + return fmt.Sprintf("Type: %s, Title: %s, Status: %d, Detail: %s, Extras: %v", hProblem.Type, hProblem.Title, hProblem.Status, hProblem.Detail, hProblem.Extras) +} + +// AcquireAdvisoryLock attempt to acquire an advisory lock on the provided lockKey, returns true if acquired, or false +// not. +func AcquireAdvisoryLock(ctx context.Context, dbConnectionPool db.DBConnectionPool, lockKey int) (bool, error) { + tssAdvisoryLockAcquired := false + sqlQuery := "SELECT pg_try_advisory_lock($1)" + err := dbConnectionPool.QueryRowxContext(ctx, sqlQuery, lockKey).Scan(&tssAdvisoryLockAcquired) + if err != nil { + return false, fmt.Errorf("querying pg_try_advisory_lock(%v): %w", lockKey, err) + } + return tssAdvisoryLockAcquired, nil +} + +type PrivateKeyEncrypter interface { + Encrypt(message string, passphrase string) (string, error) + Decrypt(message string, passphrase string) (string, error) +} + +type DefaultPrivateKeyEncrypter struct{} + +func (e DefaultPrivateKeyEncrypter) Encrypt(message, passphrase string) (string, error) { + return sdpUtils.Encrypt(message, passphrase) +} + +func (e DefaultPrivateKeyEncrypter) Decrypt(message, passphrase string) (string, error) { + return sdpUtils.Decrypt(message, passphrase) +} + +// Making sure that DefaultPrivateKeyEncrypter implements PrivateKeyEncrypter +var _ PrivateKeyEncrypter = (*DefaultPrivateKeyEncrypter)(nil) diff --git a/internal/transactionsubmission/utils/utils_test.go b/internal/transactionsubmission/utils/utils_test.go new file mode 100644 index 000000000..547209424 --- /dev/null +++ b/internal/transactionsubmission/utils/utils_test.go @@ -0,0 +1,70 @@ +package utils + +import ( + "context" + "net/http" + "testing" + + "github.com/stellar/go/clients/horizonclient" + "github.com/stellar/go/support/render/problem" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stretchr/testify/require" +) + +func Test_GetHorizonErrorString(t *testing.T) { + hError := horizonclient.Error{ + Problem: problem.P{ + Title: "Transaction Failed", + Type: "transaction_failed", + Status: http.StatusBadRequest, + Detail: "", + Extras: map[string]interface{}{ + "result_codes": map[string]interface{}{ + "transaction": "tx_failed", + "operations": []string{"op_underfunded"}, + }, + }, + }, + } + + errStr := GetHorizonErrorString(hError) + wantErrStr := "Type: transaction_failed, Title: Transaction Failed, Status: 400, Detail: , Extras: map[result_codes:map[operations:[op_underfunded] transaction:tx_failed]]" + require.Equal(t, wantErrStr, errStr) +} + +func TestAdvisoryLockAndRelease(t *testing.T) { + ctx := context.Background() + // Creates a test database: + dbt := dbtest.OpenWithoutMigrations(t) + defer dbt.Close() + + // Creates a database pool + lockKey := 123 + dbConnectionPool1, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + lockAcquired, err := AcquireAdvisoryLock(ctx, dbConnectionPool1, lockKey) + require.NoError(t, err) + + // Should be able to acquire the lock + require.True(t, lockAcquired) + require.NoError(t, err) + + // Create another database pool + dbConnectionPool2, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool2.Close() + lockAcquired2, err := AcquireAdvisoryLock(ctx, dbConnectionPool2, lockKey) + require.NoError(t, err) + // Should not be able to acquire the lock since its already been acquired + require.False(t, lockAcquired2) + + // Close the original connection which releases the lock + dbConnectionPool1.Close() + + // try to acquire the lock again + lockAcquired3, err := AcquireAdvisoryLock(ctx, dbConnectionPool2, lockKey) + require.NoError(t, err) + // Should be able to acquire the lock since we called dbConnectionPool1.Close() + require.True(t, lockAcquired3) +} diff --git a/internal/utils/crypto.go b/internal/utils/crypto.go new file mode 100644 index 000000000..4a0e1ee4b --- /dev/null +++ b/internal/utils/crypto.go @@ -0,0 +1,78 @@ +package utils + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" +) + +const keyBytes = 16 + +// Encrypt secures a message using the AES GCM cipher mode which requires the use of +// a passphrase for authentication. +func Encrypt(message string, passphrase string) (string, error) { + passHash := sha256.New() + passHash.Write([]byte(passphrase)) + + key := make([]byte, keyBytes) + copy(key, passHash.Sum(nil)) + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + + gcmCipher, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, gcmCipher.NonceSize()) + lenRead, err := rand.Read(nonce) + if err != nil { + return "", fmt.Errorf("error while generating random nonce: %w", err) + } + if lenRead != gcmCipher.NonceSize() { + return "", fmt.Errorf("length of generated nonce %d different from expected length %d", lenRead, gcmCipher.NonceSize()) + } + + cipheredText := gcmCipher.Seal(nonce, nonce, []byte(message), nil) + return base64.StdEncoding.EncodeToString(cipheredText), nil +} + +// Decrypt recovers the original message from a secured one generated by Encrypt. +func Decrypt(message string, passphrase string) (string, error) { + passHash := sha256.New() + passHash.Write([]byte(passphrase)) + + key := make([]byte, keyBytes) + copy(key, passHash.Sum(nil)) + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + + gcmCipher, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + decodedMsg, err := base64.StdEncoding.DecodeString(message) + if err != nil { + return "", err + } + + nonceSize := gcmCipher.NonceSize() + nonce, cipheredText := decodedMsg[:nonceSize], decodedMsg[nonceSize:] + + plainText, err := gcmCipher.Open(nil, nonce, cipheredText, nil) + if err != nil { + return "", err + } + + return string(plainText), nil +} diff --git a/internal/utils/crypto_test.go b/internal/utils/crypto_test.go new file mode 100644 index 000000000..5a261dc28 --- /dev/null +++ b/internal/utils/crypto_test.go @@ -0,0 +1,33 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_EncryptAndDecrypt_Success(t *testing.T) { + message := "SBJZIXEH2VE4VQRMWUSYL3PPIOPPKVR5W3LHIZUV46YB22TAB7H4AGBJ" + key := "1c4d3e4ec75106e0649825b0941fca423f752756a487847d29bb1a9704d17a70e4bac5d52be1933559bcfb43c7017b61d05f4252063f9135b270e8ea99016c03" + + encrypted, err := Encrypt(message, key) + require.NoError(t, err) + + decrypted, err := Decrypt(encrypted, key) + require.NoError(t, err) + + assert.Equal(t, message, decrypted) +} + +func Test_EncryptAndDecrypt_AuthenticationFailure(t *testing.T) { + message := "SBJZIXEH2VE4VQRMWUSYL3PPIOPPKVR5W3LHIZUV46YB22TAB7H4AGBJ" + encryptKey := "9761343c0518b89d92168804c7d7edfc74da8aef8b498d54873836c47c33641bd76b7bdccef361125c638951998076887c6445f11bd0be40feb7cfd4168857e3" + decryptKey := "1c4d3e4ec75106e0649825b0941fca423f752756a487847d29bb1a9704d17a70e4bac5d52be1933559bcfb43c7017b61d05f4252063f9135b270e8ea99016c03" + + encrypted, err := Encrypt(message, encryptKey) + require.NoError(t, err) + + _, err = Decrypt(encrypted, decryptKey) + require.Error(t, err) +} diff --git a/internal/utils/ecdsa.go b/internal/utils/ecdsa.go new file mode 100644 index 000000000..b607c3bb8 --- /dev/null +++ b/internal/utils/ecdsa.go @@ -0,0 +1,85 @@ +package utils + +import ( + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "fmt" +) + +// ParseECDSAPublicKey parses the given public key string and returns the *ecdsa.PublicKey. +func ParseECDSAPublicKey(publicKeyStr string) (*ecdsa.PublicKey, error) { + // Decode PEM block + block, _ := pem.Decode([]byte(publicKeyStr)) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block containing public key") + } + + // Parse the public key + pkixPublicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse x509 PKIX public key: %w", err) + } + + // Check if the public key is of type *ecdsa.PublicKey + publicKey, ok := pkixPublicKey.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("public key is not of type ECDSA") + } + + return publicKey, nil +} + +// ParseECDSAPrivateKey parses the given private key string and returns the *ecdsa.PrivateKey. +func ParseECDSAPrivateKey(privateKeyStr string) (*ecdsa.PrivateKey, error) { + // Decode PEM block + block, _ := pem.Decode([]byte(privateKeyStr)) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block containing private key") + } + + // Parse the private key + pkcsPrivateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse EC private key: %w", err) + } + + // Check if the public key is of type *ecdsa.PublicKey + privateKey, ok := pkcsPrivateKey.(*ecdsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("private key is not of type ECDSA") + } + + return privateKey, nil +} + +// ValidateECDSAKeys validates if the given public and private keys are a valid ECDSA keypair. +func ValidateECDSAKeys(publicKeyStr, privateKeyStr string) error { + publicKey, err := ParseECDSAPublicKey(publicKeyStr) + if err != nil { + return fmt.Errorf("validating ECDSA public key: %w", err) + } + + privateKey, err := ParseECDSAPrivateKey(privateKeyStr) + if err != nil { + return fmt.Errorf("validating ECDSA private key: %w", err) + } + + // Sign a test message using the private key + msg := "test message" + hash := sha256.Sum256([]byte(msg)) + r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:]) + if err != nil { + return fmt.Errorf("signing message for validation: %w", err) + } + + // Verify the signature using the public key + valid := ecdsa.Verify(publicKey, hash[:], r, s) + if !valid { + return fmt.Errorf("signature verification failed") + } + + return nil +} diff --git a/internal/utils/ecdsa_test.go b/internal/utils/ecdsa_test.go new file mode 100644 index 000000000..8b66fd749 --- /dev/null +++ b/internal/utils/ecdsa_test.go @@ -0,0 +1,214 @@ +package utils + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "math/big" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ecdsaKeypair struct { + privateKeyStr string + publicKeyStr string +} + +var ( + keypair1 = ecdsaKeypair{ + publicKeyStr: `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAER88h7AiQyVDysRTxKvBB6CaiO/kS +cvGyimApUE/12gFhNTRf37SE19CSCllKxstnVFOpLLWB7Qu5OJ0Wvcz3hg== +-----END PUBLIC KEY-----`, + privateKeyStr: `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgIqI1MzMZIw2pQDLx +Jn0+FcNT/hNjwtn2TW43710JKZqhRANCAARHzyHsCJDJUPKxFPEq8EHoJqI7+RJy +8bKKYClQT/XaAWE1NF/ftITX0JIKWUrGy2dUU6kstYHtC7k4nRa9zPeG +-----END PRIVATE KEY-----`, + } + keypair2 = ecdsaKeypair{ + publicKeyStr: `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAERJtGEWVxHTOghAFU9XyANbF10aXK +zT3U72jUfBk38fceemINJERxdLbBs2O1foeFd8HyJ6Zn7tLvZWGNvVN+cA== +-----END PUBLIC KEY-----`, + privateKeyStr: `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgw8lMqTKWEdxusLOW +J16L7THmguSKZq1PPS1SRravKpOhRANCAAREm0YRZXEdM6CEAVT1fIA1sXXRpcrN +PdTvaNR8GTfx9x56Yg0kRHF0tsGzY7V+h4V3wfInpmfu0u9lYY29U35w +-----END PRIVATE KEY-----`, + } +) + +func Test_ParseECDSAPublicKey(t *testing.T) { + // publicKeyObj is the public key object that corresponds to the keypair1.publicKeyStr + bigIntX := new(big.Int) + bigIntX.SetString("32480183712899956666963574445105818726761898573293978186307012095310684346881", 10) + bigIntY := new(big.Int) + bigIntY.SetString("43968350682573962747988640660801043718476300246351425025163140929681875597190", 10) + publicKeyObj := &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: bigIntX, + Y: bigIntY, + } + + testCases := []struct { + name string + value string + wantResult *ecdsa.PublicKey + wantErrContains string + }{ + { + name: "returns an error if the value is not a PEM string", + value: "not-a-pem-string", + wantErrContains: "failed to decode PEM block containing public key", + }, + { + name: "returns an error if the value is not a x509 string", + value: "-----BEGIN MY STRING-----\nYWJjZA==\n-----END MY STRING-----", + wantErrContains: "failed to parse x509 PKIX public key", + }, + { + name: "returns an error if the value is not a ECDSA public key", + value: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyNPqmozv8a2PnXHIkV+F\nmWMFy2YhOFzX12yzjjWkJ3rI9QSEomz4Unkwc6oYrnKEDYlnAgCiCqL2zPr5qNkX\nk5MPU87/wLgEqp7uAk0GkJZfrhJIYZ5AuG9+o69BNeQDEi7F3YdMJj9bvs2Ou1FN\n1zG/8HV969rJ/63fzWsqlNon1j4H5mJ0YbmVh/QLcYPmv7feFZGEj4OSZ4u+eJsw\nat5NPyhMgo6uB/goNS3fEY29UNvXoSIN3hnK3WSxQ79Rjn4V4so7ehxzCVPjnm/G\nFFTgY0hGBobmnxbjI08hEZmYKosjan4YqydGETjKR3UlhBx9y/eqqgL+opNJ8vJs\n2QIDAQAB\n-----END PUBLIC KEY-----", + wantErrContains: "public key is not of type ECDSA", + }, + { + name: "πŸŽ‰ Successfully handles a valid ECDSA public key", + value: keypair1.publicKeyStr, + wantResult: publicKeyObj, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotPublicKey, err := ParseECDSAPublicKey(tc.value) + if tc.wantErrContains == "" { + assert.NotNil(t, gotPublicKey) + assert.Equal(t, tc.wantResult, gotPublicKey) + assert.NoError(t, err) + } else { + assert.Nil(t, gotPublicKey) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrContains) + } + }) + } +} + +func Test_ParseECDSAPrivateKey(t *testing.T) { + // privateKeyObj is the public key object that corresponds to the keypair1.privateKeyStr + bigIntX := new(big.Int) + bigIntX.SetString("32480183712899956666963574445105818726761898573293978186307012095310684346881", 10) + bigIntY := new(big.Int) + bigIntY.SetString("43968350682573962747988640660801043718476300246351425025163140929681875597190", 10) + publicKeyObj := &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: bigIntX, + Y: bigIntY, + } + bigIntD := new(big.Int) + bigIntD.SetString("15665233249220082997812441880036381661021061746430729869708887737553839008154", 10) + privateKeyObj := &ecdsa.PrivateKey{ + PublicKey: *publicKeyObj, + D: bigIntD, + } + + testCases := []struct { + name string + value string + wantResult *ecdsa.PrivateKey + wantErrContains string + }{ + { + name: "returns an error if the value is not a PEM string", + value: "not-a-pem-string", + wantErrContains: "failed to decode PEM block containing private key", + }, + { + name: "returns an error if the value is not a x509 string", + value: "-----BEGIN MY STRING-----\nYWJjZA==\n-----END MY STRING-----", + wantErrContains: "failed to parse EC private key", + }, + { + name: "returns an error if the value is not a ECDSA private key", + value: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyNPqmozv8a2PnXHIkV+F\nmWMFy2YhOFzX12yzjjWkJ3rI9QSEomz4Unkwc6oYrnKEDYlnAgCiCqL2zPr5qNkX\nk5MPU87/wLgEqp7uAk0GkJZfrhJIYZ5AuG9+o69BNeQDEi7F3YdMJj9bvs2Ou1FN\n1zG/8HV969rJ/63fzWsqlNon1j4H5mJ0YbmVh/QLcYPmv7feFZGEj4OSZ4u+eJsw\nat5NPyhMgo6uB/goNS3fEY29UNvXoSIN3hnK3WSxQ79Rjn4V4so7ehxzCVPjnm/G\nFFTgY0hGBobmnxbjI08hEZmYKosjan4YqydGETjKR3UlhBx9y/eqqgL+opNJ8vJs\n2QIDAQAB\n-----END PUBLIC KEY-----", + wantErrContains: "failed to parse EC private key", + }, + { + name: "πŸŽ‰ Successfully handles a valid ECDSA private key", + wantResult: privateKeyObj, + value: keypair1.privateKeyStr, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotPrivateKey, err := ParseECDSAPrivateKey(tc.value) + if tc.wantErrContains == "" { + assert.Equal(t, tc.wantResult, gotPrivateKey) + assert.NoError(t, err) + } else { + assert.Nil(t, gotPrivateKey) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrContains) + } + }) + } +} + +func Test_ValidateECDSAKeys(t *testing.T) { + testCases := []struct { + name string + publicKeyStr string + privateKeyStr string + wantErrContains string + }{ + { + name: "returns an error if the public key is not a PEM string", + publicKeyStr: "not-a-pem-string", + privateKeyStr: keypair1.privateKeyStr, + wantErrContains: "validating ECDSA public key: failed to decode PEM block containing public key", + }, + { + name: "returns an error if the public key is valid but the private key is not a x509 string", + publicKeyStr: keypair1.publicKeyStr, + privateKeyStr: "-----BEGIN MY STRING-----\nYWJjZA==\n-----END MY STRING-----", + wantErrContains: "validating ECDSA private key: failed to parse EC private key", + }, + { + name: "returns an error if the keys are not a pair (1 & 2)", + publicKeyStr: keypair1.publicKeyStr, + privateKeyStr: keypair2.privateKeyStr, + wantErrContains: "signature verification failed", + }, + { + name: "returns an error if the keys are not a pair (2 & 1)", + publicKeyStr: keypair2.publicKeyStr, + privateKeyStr: keypair1.privateKeyStr, + wantErrContains: "signature verification failed", + }, + { + name: "πŸŽ‰ Successfully validates a valid ECDSA key pair (1)", + publicKeyStr: keypair1.publicKeyStr, + privateKeyStr: keypair1.privateKeyStr, + }, + { + name: "πŸŽ‰ Successfully validates a valid ECDSA key pair (2)", + publicKeyStr: keypair2.publicKeyStr, + privateKeyStr: keypair2.privateKeyStr, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateECDSAKeys(tc.publicKeyStr, tc.privateKeyStr) + if tc.wantErrContains == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrContains) + } + }) + } +} diff --git a/internal/utils/float.go b/internal/utils/float.go new file mode 100644 index 000000000..98324e092 --- /dev/null +++ b/internal/utils/float.go @@ -0,0 +1,8 @@ +package utils + +import "strconv" + +// FloatToString converts a float number to a string with 7 decimal places. +func FloatToString(inputNum float64) string { + return strconv.FormatFloat(inputNum, 'f', 7, 64) +} diff --git a/internal/utils/float_test.go b/internal/utils/float_test.go new file mode 100644 index 000000000..d200da138 --- /dev/null +++ b/internal/utils/float_test.go @@ -0,0 +1,27 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_FloatToString(t *testing.T) { + testCases := []struct { + floatInput float64 + wantStringOutput string + }{ + {floatInput: 1.2345678, wantStringOutput: "1.2345678"}, + {floatInput: 1.23456784, wantStringOutput: "1.2345678"}, + {floatInput: 1.23456789, wantStringOutput: "1.2345679"}, + {floatInput: 1.0, wantStringOutput: "1.0000000"}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf(tc.wantStringOutput), func(t *testing.T) { + gotStringOutput := FloatToString(tc.floatInput) + assert.Equal(t, tc.wantStringOutput, gotStringOutput) + }) + } +} diff --git a/internal/utils/network_type.go b/internal/utils/network_type.go new file mode 100644 index 000000000..7cec3f01d --- /dev/null +++ b/internal/utils/network_type.go @@ -0,0 +1,25 @@ +package utils + +import ( + "fmt" + + "github.com/stellar/go/network" +) + +type NetworkType string + +const ( + PubnetNetworkType NetworkType = "pubnet" + TestnetNetworkType NetworkType = "testnet" +) + +func GetNetworkTypeFromNetworkPassphrase(networkPassphrase string) (NetworkType, error) { + switch networkPassphrase { + case network.PublicNetworkPassphrase: + return PubnetNetworkType, nil + case network.TestNetworkPassphrase: + return TestnetNetworkType, nil + default: + return "", fmt.Errorf("invalid network passphrase provided") + } +} diff --git a/internal/utils/network_type_test.go b/internal/utils/network_type_test.go new file mode 100644 index 000000000..19d5d891d --- /dev/null +++ b/internal/utils/network_type_test.go @@ -0,0 +1,42 @@ +package utils + +import ( + "testing" + + "github.com/stellar/go/network" + "github.com/stretchr/testify/assert" +) + +func Test_GetNetworkTypeFromNetworkPassphrase(t *testing.T) { + testCases := []struct { + networkPassphrase string + expectedNetworkType NetworkType + expectedError string + }{ + { + networkPassphrase: network.PublicNetworkPassphrase, + expectedNetworkType: PubnetNetworkType, + expectedError: "", + }, + { + networkPassphrase: network.TestNetworkPassphrase, + expectedNetworkType: TestnetNetworkType, + expectedError: "", + }, + { + networkPassphrase: "invalid", + expectedNetworkType: "", + expectedError: "invalid network passphrase provided", + }, + } + + for _, tc := range testCases { + networkType, err := GetNetworkTypeFromNetworkPassphrase(tc.networkPassphrase) + assert.Equal(t, tc.expectedNetworkType, networkType) + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + } else { + assert.Nil(t, err) + } + } +} diff --git a/internal/utils/string.go b/internal/utils/string.go new file mode 100644 index 000000000..e072e336d --- /dev/null +++ b/internal/utils/string.go @@ -0,0 +1,40 @@ +package utils + +import ( + "crypto/rand" + "fmt" + "math/big" +) + +const ( + letterBytes = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + NumberBytes = "0123456789" +) + +func RandomString(size int, charSetOptions ...string) (string, error) { + charSet := letterBytes + if len(charSetOptions) > 0 { + charSet = "" + for _, cs := range charSetOptions { + charSet += cs + } + } + + b := make([]byte, size) + for i := range b { + randInt, err := rand.Int(rand.Reader, big.NewInt(int64(len(charSet)))) + if err != nil { + return "", fmt.Errorf("error generating random number in RandomString: %w", err) + } + + b[i] = charSet[randInt.Int64()] + } + return string(b), nil +} + +func TruncateString(str string, borderSizeToKeep int) string { + if len(str) <= 2*borderSizeToKeep { + return str + } + return str[:borderSizeToKeep] + "..." + str[len(str)-borderSizeToKeep:] +} diff --git a/internal/utils/string_test.go b/internal/utils/string_test.go new file mode 100644 index 000000000..1837f880e --- /dev/null +++ b/internal/utils/string_test.go @@ -0,0 +1,70 @@ +package utils + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_RandomString(t *testing.T) { + randomString1, err := RandomString(10) + require.NoError(t, err) + require.Len(t, randomString1, 10) + randomString2, err := RandomString(10) + require.NoError(t, err) + require.Len(t, randomString2, 10) + require.NotEqual(t, randomString1, randomString2) + + randomString3, err := RandomString(5) + require.NoError(t, err) + require.Len(t, randomString3, 5) + + randomString4, err := RandomString(6, NumberBytes) + require.NoError(t, err) + require.Len(t, randomString4, 6) + onlyNumbers := regexp.MustCompile(`\d`).MatchString(randomString4) + assert.True(t, onlyNumbers) +} + +func Test_TruncateString(t *testing.T) { + testCases := []struct { + name string + rawString string + borderSizeToKeep int + wantTruncated string + }{ + { + name: "string is shorter than borderSizeToKeep", + rawString: "abc", + borderSizeToKeep: 4, + wantTruncated: "abc", + }, + { + name: "string is longer than borderSizeToKeep", + rawString: "abcdefg", + borderSizeToKeep: 3, + wantTruncated: "abc...efg", + }, + { + name: "string is same length as borderSizeToKeep", + rawString: "abcdef", + borderSizeToKeep: 3, + wantTruncated: "abcdef", + }, + { + name: "string is empty", + rawString: "", + borderSizeToKeep: 3, + wantTruncated: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotTruncated := TruncateString(tc.rawString, tc.borderSizeToKeep) + assert.Equal(t, tc.wantTruncated, gotTruncated, "Expected Truncate(%q, %d) to be %q, but got %q", tc.rawString, tc.borderSizeToKeep, tc.wantTruncated, gotTruncated) + }) + } +} diff --git a/internal/utils/url.go b/internal/utils/url.go new file mode 100644 index 000000000..bfc049184 --- /dev/null +++ b/internal/utils/url.go @@ -0,0 +1,74 @@ +package utils + +import ( + "encoding/hex" + "fmt" + "net/url" + + "github.com/stellar/go/keypair" +) + +func SignURL(stellarSecretKey string, rawURL string) (string, error) { + // Validate stellar private key + kp, err := keypair.ParseFull(stellarSecretKey) + if err != nil { + return "", fmt.Errorf("error parsing stellar private key: %w", err) + } + + // Validate raw url + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("error parsing raw url: %w", err) + } + if u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("raw url %q should have both a scheme and a host", rawURL) + } + + // Sign url + u.RawQuery = u.Query().Encode() + signature, err := kp.Sign([]byte(u.String())) + if err != nil { + return "", fmt.Errorf("error signing url: %w", err) + } + signatureHex := hex.EncodeToString(signature) + signedURL := u.String() + "&signature=" + signatureHex + + return signedURL, nil +} + +func VerifySignedURL(signedURL string, expectedPublicKey string) (bool, error) { + // Validate expected public key + pubKey, err := keypair.ParseAddress(expectedPublicKey) + if err != nil { + return false, fmt.Errorf("error parsing expected public key: %w", err) + } + + // Validate signed URL + u, err := url.Parse(signedURL) + if err != nil { + return false, fmt.Errorf("error parsing signed url: %w", err) + } + + // Extract signature from signed URL + query := u.Query() + signatureHex := query.Get("signature") + if signatureHex == "" { + return false, fmt.Errorf("signed url does not contain a signature") + } + signature, err := hex.DecodeString(signatureHex) + if err != nil { + return false, fmt.Errorf("error decoding signature: %w", err) + } + + // Remove signature from URL + query.Del("signature") + u.RawQuery = query.Encode() + + // Verify signature + err = pubKey.Verify([]byte(u.String()), signature) + if err != nil { + return false, fmt.Errorf("error verifying URL signature: %w", err) + } + + return true, nil +} diff --git a/internal/utils/url_test.go b/internal/utils/url_test.go new file mode 100644 index 000000000..5968af24e --- /dev/null +++ b/internal/utils/url_test.go @@ -0,0 +1,125 @@ +package utils + +import ( + "strings" + "testing" + + "github.com/stellar/go/keypair" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_SignURL(t *testing.T) { + // rawURL := https://vibrantapp.com/sdp-dev?domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar%20Test&asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5 + // PublicKey: GBFDUUZ5ZYC6RAPOQLM7IYXLFHYTMCYXBGM7NIC4EE2MWOSGIYCOSN5F + // PrivateKey: SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5 + // result: https://vibrantapp.com/sdp-dev?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar+Test&signature=60bb8ed15df271131bb2d7c87fd5649a9a69bf655c5ffcff3816c766cfd98356381a7d4c03494c4bb9eb25e1167a399845aae73ec667990d840e9fc43af6e906 + + testCases := []struct { + name string + stellarSecretKey string + rawURL string + wantSignedURL string + wantErrContains string + }{ + { + name: "returns an error if stellarSecretKey is empty", + wantErrContains: "error parsing stellar private key: strkey is 0 bytes long; minimum valid length is 5", + }, + { + name: "returns an error if stellarSecretKey is invalid", + stellarSecretKey: "INVALID_SECRET_KEY", + wantErrContains: "error parsing stellar private key: base32 decode failed: illegal base32 data at input byte 7", + }, + { + name: "returns an error if rawURL is empty", + stellarSecretKey: "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5", + wantErrContains: `raw url "" should have both a scheme and a host`, + }, + { + name: "returns an error if rawURL has a host without scheme", + stellarSecretKey: "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5", + rawURL: "host-without-scheme", + wantErrContains: `raw url "host-without-scheme" should have both a scheme and a host`, + }, + { + name: "returns an error if rawURL has a scheme without host", + stellarSecretKey: "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5", + rawURL: "scheme-without-host://", + wantErrContains: `raw url "scheme-without-host://" should have both a scheme and a host`, + }, + { + name: "πŸŽ‰ successfully signs the desired url", + stellarSecretKey: "SBUSPEKAZKLZSWHRSJ2HWDZUK6I3IVDUWA7JJZSGBLZ2WZIUJI7FPNB5", + rawURL: "https://vibrantapp.com/sdp-dev?domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar%20Test&asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5", + wantSignedURL: "https://vibrantapp.com/sdp-dev?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar+Test&signature=fea6c5e805a29b903835bea2f6c60069113effdf1c5cb448d4948573c65557b1d667bcd176c24a94ed9d54a1829317c74f39319076511512a3e697b4b746ae0a", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotSignedURL, err := SignURL(tc.stellarSecretKey, tc.rawURL) + if tc.wantErrContains != "" { + assert.Empty(t, gotSignedURL) + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrContains) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.wantSignedURL, gotSignedURL) + } + }) + } +} + +func Test_VerifySignedURL(t *testing.T) { + // signedURL example from previous test + signedURL := "https://vibrantapp.com/sdp-dev/aid?asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5&domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar+Test&signature=60bb8ed15df271131bb2d7c87fd5649a9a69bf655c5ffcff3816c766cfd98356381a7d4c03494c4bb9eb25e1167a399845aae73ec667990d840e9fc43af6e906" + expectedPublicKey := "GBFDUUZ5ZYC6RAPOQLM7IYXLFHYTMCYXBGM7NIC4EE2MWOSGIYCOSN5F" + + // expectedPublicKey cannot be empty + isValid, err := VerifySignedURL(signedURL, "") + require.False(t, isValid) + require.EqualError(t, err, "error parsing expected public key: strkey is 0 bytes long; minimum valid length is 5") + + // invalid expectedPublicKey + isValid, err = VerifySignedURL(signedURL, "INVALID_PUBLIC_KEY") + require.False(t, isValid) + require.EqualError(t, err, "error parsing expected public key: base32 decode failed: illegal base32 data at input byte 7") + + // signedURL cannot be empty + isValid, err = VerifySignedURL("", expectedPublicKey) + require.False(t, isValid) + require.EqualError(t, err, "signed url does not contain a signature") + + // invalid signedURL + isValid, err = VerifySignedURL("invalid_signed_url", expectedPublicKey) + require.False(t, isValid) + require.EqualError(t, err, "signed url does not contain a signature") + + // valid signedURL and expectedPublicKey πŸŽ‰ + isValid, err = VerifySignedURL(signedURL, expectedPublicKey) + require.NoError(t, err) + require.True(t, isValid) + + // valid signedURL and expectedPublicKey but signature is invalid + tamperedURL := strings.Replace(signedURL, "USDC", "USD", 1) + isValid, err = VerifySignedURL(tamperedURL, expectedPublicKey) + require.False(t, isValid) + require.EqualError(t, err, "error verifying URL signature: signature verification failed") +} + +func Test_SignURL_VerifySignedURL(t *testing.T) { + kp, err := keypair.Random() + require.NoError(t, err) + + // valid rawURL and stellarSecretKey πŸŽ‰ + validURL := "https://vibrantapp.com/sdp-dev/aid?domain=ap-stellar-disbursement-platform-backend-dev.stellar.org&name=Stellar%20Test&asset=USDC-GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5" + gotSignedURL, err := SignURL(kp.Seed(), validURL) + require.NoError(t, err) + require.NotEmpty(t, gotSignedURL) + + // valid signedURL and expectedPublicKey πŸŽ‰ + isValid, err := VerifySignedURL(gotSignedURL, kp.Address()) + require.NoError(t, err) + require.True(t, isValid) +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 000000000..4e6f25a77 --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,52 @@ +package utils + +import ( + "net/http" + "reflect" + + "github.com/go-chi/chi/v5" +) + +func GetRoutePattern(r *http.Request) string { + rctx := chi.RouteContext(r.Context()) + if pattern := rctx.RoutePattern(); pattern != "" { + // Pattern is already available + return pattern + } + + routePath := r.URL.Path + + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } + + tctx := chi.NewRouteContext() + if !rctx.Routes.Match(tctx, r.Method, routePath) { + return "undefined" + } + + // tctx has the updated pattern, since Match mutates it + return tctx.RoutePattern() +} + +// UnwrapInterfaceToPointer unwraps an interface to a pointer of the given type. +func UnwrapInterfaceToPointer[T any](i interface{}) *T { + t, ok := i.(*T) + if ok { + return t + } + return nil +} + +// IsEmpty checks if a value is empty. +func IsEmpty[T any](v T) bool { + return reflect.ValueOf(&v).Elem().IsZero() +} + +func MapSlice[T any, M any](a []T, f func(T) M) []M { + n := make([]M, len(a)) + for i, e := range a { + n[i] = f(e) + } + return n +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go new file mode 100644 index 000000000..2f41d966f --- /dev/null +++ b/internal/utils/utils_test.go @@ -0,0 +1,149 @@ +package utils + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetRoutePattern(t *testing.T) { + testCases := []struct { + expectedRoutePattern string + method string + }{ + {expectedRoutePattern: "/mock", method: "GET"}, + {expectedRoutePattern: "undefined", method: "POST"}, + } + + mHttpHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + for _, tc := range testCases { + t.Run("getting route pattern", func(t *testing.T) { + mAssertRoutePattern := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + routePattern := GetRoutePattern(req) + + assert.Equal(t, tc.expectedRoutePattern, routePattern) + next.ServeHTTP(rw, req) + }) + } + + r := chi.NewRouter() + r.Use(mAssertRoutePattern) + r.Get("/mock", mHttpHandler.ServeHTTP) + + req, err := http.NewRequest(tc.method, "/mock", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + }) + } +} + +func Test_UnwrapInterfaceToPointer(t *testing.T) { + // Test with a string + strValue := "test" + strValuePtr := &strValue + i := interface{}(strValuePtr) + + unwrappedValue := UnwrapInterfaceToPointer[string](i) + assert.Equal(t, "test", *unwrappedValue) + + // Test with a struct + type testStruct struct { + Name string + } + testStructValue := testStruct{Name: "test"} + testStructValuePtr := &testStructValue + i = interface{}(testStructValuePtr) + assert.Equal(t, testStruct{Name: "test"}, *UnwrapInterfaceToPointer[testStruct](i)) +} + +func Test_IsEmpty(t *testing.T) { + type testCase struct { + name string + isEmptyFn func() bool + expected bool + } + + // testStruct is used just for testing empty and non empty structs. + type testStruct struct{ Name string } + + // Define test cases + testCases := []testCase{ + // String + {name: "String empty", isEmptyFn: func() bool { return IsEmpty[string]("") }, expected: true}, + {name: "String non-empty", isEmptyFn: func() bool { return IsEmpty[string]("not empty") }, expected: false}, + // Int + {name: "Int zero", isEmptyFn: func() bool { return IsEmpty[int](0) }, expected: true}, + {name: "Int non-zero", isEmptyFn: func() bool { return IsEmpty[int](1) }, expected: false}, + // Slice: + {name: "Slice nil", isEmptyFn: func() bool { return IsEmpty[[]string](nil) }, expected: true}, + {name: "Slice empty", isEmptyFn: func() bool { return IsEmpty[[]string]([]string{}) }, expected: false}, + {name: "Slice non-empty", isEmptyFn: func() bool { return IsEmpty[[]string]([]string{"not empty"}) }, expected: false}, + // Struct: + {name: "Struct zero", isEmptyFn: func() bool { return IsEmpty[testStruct](testStruct{}) }, expected: true}, + {name: "Struct non-zero", isEmptyFn: func() bool { return IsEmpty[testStruct](testStruct{Name: "not empty"}) }, expected: false}, + // Pointer: + {name: "Pointer nil", isEmptyFn: func() bool { return IsEmpty[*string](nil) }, expected: true}, + {name: "Pointer non-nil", isEmptyFn: func() bool { return IsEmpty[*string](new(string)) }, expected: false}, + // Function: + {name: "Function nil", isEmptyFn: func() bool { return IsEmpty[func() string](nil) }, expected: true}, + {name: "Function non-nil", isEmptyFn: func() bool { return IsEmpty[func() string](func() string { return "not empty" }) }, expected: false}, + // Interface: + {name: "Interface nil", isEmptyFn: func() bool { return IsEmpty[interface{}](nil) }, expected: true}, + {name: "Interface non-nil", isEmptyFn: func() bool { return IsEmpty[interface{}](new(string)) }, expected: false}, + // Map: + {name: "Map nil", isEmptyFn: func() bool { return IsEmpty[map[string]string](nil) }, expected: true}, + {name: "Map empty", isEmptyFn: func() bool { return IsEmpty[map[string]string](map[string]string{}) }, expected: false}, + {name: "Map non-empty", isEmptyFn: func() bool { return IsEmpty[map[string]string](map[string]string{"not empty": "not empty"}) }, expected: false}, + // Channel: + {name: "Channel nil", isEmptyFn: func() bool { return IsEmpty[chan string](nil) }, expected: true}, + {name: "Channel non-nil", isEmptyFn: func() bool { return IsEmpty[chan string](make(chan string)) }, expected: false}, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.isEmptyFn()) + }) + } +} + +func Test_MapSlice(t *testing.T) { + testCases := []struct { + name string + prepareMapSliceFn func() interface{} + wantMapped interface{} + }{ + { + name: "map to string slice to uppercased string slice", + prepareMapSliceFn: func() interface{} { + return MapSlice([]string{"a", "b", "c"}, strings.ToUpper) + }, + wantMapped: []string{"A", "B", "C"}, + }, + { + name: "map int slice to string slice", + prepareMapSliceFn: func() interface{} { + return MapSlice([]int{1, 2, 3}, func(input int) string { return fmt.Sprintf("%d", input) }) + }, + wantMapped: []string{"1", "2", "3"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotMapped := tc.prepareMapSliceFn() + require.Equal(t, tc.wantMapped, gotMapped) + }) + } +} diff --git a/internal/utils/validation.go b/internal/utils/validation.go new file mode 100644 index 000000000..c19e84e68 --- /dev/null +++ b/internal/utils/validation.go @@ -0,0 +1,91 @@ +package utils + +import ( + "fmt" + "regexp" + "strconv" + + "github.com/asaskevich/govalidator" + "github.com/nyaruka/phonenumbers" +) + +var ( + // RxPhone is a regex used to validate phone number, according with the E.164 standard https://en.wikipedia.org/wiki/E.164 + rxPhone = regexp.MustCompile(`^\+[1-9]{1}[0-9]{9,14}$`) + rxOTP = regexp.MustCompile(`^\d{6}$`) + ErrInvalidE164PhoneNumber = fmt.Errorf("the provided phone number is not a valid E.164 number") +) + +// https://github.com/firebase/firebase-admin-go/blob/cef91acd46f2fc5d0b3408d8154a0005db5bdb0b/auth/user_mgt.go#L449-L457 +func ValidatePhoneNumber(phoneNumberStr string) error { + if phoneNumberStr == "" { + return fmt.Errorf("phone number cannot be empty") + } + + if !rxPhone.MatchString(phoneNumberStr) { + return ErrInvalidE164PhoneNumber + } + + parsedNumber, err := phonenumbers.Parse(phoneNumberStr, "") + if err != nil || !phonenumbers.IsValidNumber(parsedNumber) { + // Parsing error, not a valid phone number + return ErrInvalidE164PhoneNumber + } + + return nil +} + +func ValidateAmount(amount string) error { + if amount == "" { + return fmt.Errorf("amount cannot be empty") + } + + value, err := strconv.ParseFloat(amount, 64) + if err != nil { + return fmt.Errorf("the provided amount is not a valid number") + } + + if value <= 0 { + return fmt.Errorf("the provided amount must be greater than zero") + } + + return nil +} + +// RxEmail is a regex used to validate e-mail addresses, according with the reference https://www.alexedwards.net/blog/validation-snippets-for-go#email-validation. +// It's free to use under the [MIT Licence](https://opensource.org/licenses/MIT) +var rxEmail = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") + +func ValidateEmail(email string) error { + if email == "" { + return fmt.Errorf("email cannot be empty") + } + + if !rxEmail.MatchString(email) { + return fmt.Errorf("the provided email is not valid") + } + + return nil +} + +// IsDNSName will validate the given string as a DNS name +func ValidateDNS(domain string) error { + isDNS := govalidator.IsDNSName(domain) + if !isDNS { + return fmt.Errorf("%q is not a valid DNS name", domain) + } + + return nil +} + +func ValidateOTP(otp string) error { + if otp == "" { + return fmt.Errorf("otp cannot be empty") + } + + if !rxOTP.MatchString(otp) { + return fmt.Errorf("the provided OTP is not a valid 6 digits value") + } + + return nil +} diff --git a/internal/utils/validation_test.go b/internal/utils/validation_test.go new file mode 100644 index 000000000..e8e0841df --- /dev/null +++ b/internal/utils/validation_test.go @@ -0,0 +1,141 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ValidatePhoneNumber(t *testing.T) { + testCases := []struct { + phoneNumber string + wantErr error + }{ + {"", fmt.Errorf("phone number cannot be empty")}, + {"notvalidphone", ErrInvalidE164PhoneNumber}, + {"14155555555", ErrInvalidE164PhoneNumber}, + {"+380445555555", nil}, + {"+14155555555x4444", ErrInvalidE164PhoneNumber}, + {"+1 415 555 5555", ErrInvalidE164PhoneNumber}, + {"+1 415-555-5555", ErrInvalidE164PhoneNumber}, + {"+05555555555", ErrInvalidE164PhoneNumber}, + {"++5555555555", ErrInvalidE164PhoneNumber}, + {"+38012345678", ErrInvalidE164PhoneNumber}, + {"+38056789013", ErrInvalidE164PhoneNumber}, + {"+38034567890", ErrInvalidE164PhoneNumber}, + {"+15555555555", ErrInvalidE164PhoneNumber}, + {"+14155555555", nil}, + } + + for _, tc := range testCases { + t.Run(tc.phoneNumber, func(t *testing.T) { + gotError := ValidatePhoneNumber(tc.phoneNumber) + assert.Equalf(t, tc.wantErr, gotError, "ValidatePhoneNumber(%q) should be %v, but got %v", tc.phoneNumber, tc.wantErr, gotError) + }) + } +} + +func Test_ValidateAmount(t *testing.T) { + testCases := []struct { + amount string + wantErr error + }{ + {"", fmt.Errorf("amount cannot be empty")}, + {"notvalidamount", fmt.Errorf("the provided amount is not a valid number")}, + {"0", fmt.Errorf("the provided amount must be greater than zero")}, + {"0.00", fmt.Errorf("the provided amount must be greater than zero")}, + {"1", nil}, + {"1.00", nil}, + {"1.01", nil}, + } + + for _, tc := range testCases { + t.Run(tc.amount, func(t *testing.T) { + gotError := ValidateAmount(tc.amount) + assert.Equalf(t, tc.wantErr, gotError, "ValidateAmount(%q) should be %v, but got %v", tc.amount, tc.wantErr, gotError) + }) + } +} + +func Test_ValidateEmail(t *testing.T) { + testCases := []struct { + email string + wantErr error + }{ + {"", fmt.Errorf("email cannot be empty")}, + {"notvalidemail", fmt.Errorf("the provided email is not valid")}, + {"valid@test.com", nil}, + {"valid+email@test.com", nil}, + } + + for _, tc := range testCases { + t.Run(tc.email, func(t *testing.T) { + gotError := ValidateEmail(tc.email) + assert.Equalf(t, tc.wantErr, gotError, "ValidateEmail(%q) should be %v, but got %v", tc.email, tc.wantErr, gotError) + }) + } +} + +func Test_ValidateDNS(t *testing.T) { + testCases := []struct { + url string + wantErr error + }{ + {"localhost", nil}, + {"a.bc", nil}, + {"test.com", nil}, + {"a.b..", fmt.Errorf(`"a.b.." is not a valid DNS name`)}, + {"localhost.local", nil}, + {"localhost.localdomain.intern", nil}, + {"l.local.intern", nil}, + {"ru.link.n.svpncloud.com", nil}, + {"-localhost", fmt.Errorf(`"-localhost" is not a valid DNS name`)}, + {"localhost.-localdomain", fmt.Errorf(`"localhost.-localdomain" is not a valid DNS name`)}, + {"localhost.localdomain.-int", fmt.Errorf(`"localhost.localdomain.-int" is not a valid DNS name`)}, + {"localhost._localdomain", nil}, + {"localhost.localdomain._int", nil}, + {"lΓ–calhost", fmt.Errorf(`"lΓ–calhost" is not a valid DNS name`)}, + {"localhost.lΓ–caldomain", fmt.Errorf(`"localhost.lΓ–caldomain" is not a valid DNS name`)}, + {"localhost.localdomain.ΓΌntern", fmt.Errorf(`"localhost.localdomain.ΓΌntern" is not a valid DNS name`)}, + {"localhost/", fmt.Errorf(`"localhost/" is not a valid DNS name`)}, + {"127.0.0.1", fmt.Errorf(`"127.0.0.1" is not a valid DNS name`)}, + {"[::1]", fmt.Errorf(`"[::1]" is not a valid DNS name`)}, + {"50.50.50.50", fmt.Errorf(`"50.50.50.50" is not a valid DNS name`)}, + {"localhost.localdomain.intern:65535", fmt.Errorf(`"localhost.localdomain.intern:65535" is not a valid DNS name`)}, + {"漒字汉字", fmt.Errorf(`"漒字汉字" is not a valid DNS name`)}, + {"www.jubfvq1v3p38i51622y0dvmdk1mymowjyeu26gbtw9andgynj1gg8z3msb1kl5z6906k846pj3sulm4kiyk82ln5teqj9nsht59opr0cs5ssltx78lfyvml19lfq1wp4usbl0o36cmiykch1vywbttcus1p9yu0669h8fj4ll7a6bmop505908s1m83q2ec2qr9nbvql2589adma3xsq2o38os2z3dmfh2tth4is4ixyfasasasefqwe4t2ub2fz1rme.de", fmt.Errorf(`"www.jubfvq1v3p38i51622y0dvmdk1mymowjyeu26gbtw9andgynj1gg8z3msb1kl5z6906k846pj3sulm4kiyk82ln5teqj9nsht59opr0cs5ssltx78lfyvml19lfq1wp4usbl0o36cmiykch1vywbttcus1p9yu0669h8fj4ll7a6bmop505908s1m83q2ec2qr9nbvql2589adma3xsq2o38os2z3dmfh2tth4is4ixyfasasasefqwe4t2ub2fz1rme.de" is not a valid DNS name`)}, + } + + for _, tc := range testCases { + t.Run(tc.url, func(t *testing.T) { + gotError := ValidateDNS(tc.url) + + if tc.wantErr != nil { + assert.EqualErrorf(t, gotError, tc.wantErr.Error(), "ValidateURL(%q) should be '%v', but got '%v'", tc.url, tc.wantErr, gotError) + } else { + assert.NoError(t, gotError) + } + }) + } +} + +func Test_ValidateOTP(t *testing.T) { + testCases := []struct { + otp string + wantErr error + }{ + {"", fmt.Errorf("otp cannot be empty")}, + {"mock", fmt.Errorf("the provided OTP is not a valid 6 digits value")}, + {"123", fmt.Errorf("the provided OTP is not a valid 6 digits value")}, + {"12mock", fmt.Errorf("the provided OTP is not a valid 6 digits value")}, + {"123456", nil}, + } + + for _, tc := range testCases { + t.Run(tc.otp, func(t *testing.T) { + gotError := ValidateOTP(tc.otp) + assert.Equalf(t, tc.wantErr, gotError, "ValidateOTP(%q) should be %v, but got %v", tc.otp, tc.wantErr, gotError) + }) + } +} diff --git a/main.go b/main.go new file mode 100644 index 000000000..abf45ea0f --- /dev/null +++ b/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "os" + + "github.com/sirupsen/logrus" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/cmd" +) + +// Version is the official version of this application. Whenever it's changed +// here, it also needs to be updated at the `helmchart/Chart.yaml#appVersionβ€œ. +const Version = "0.2.0" + +// GitCommit is populated at build time by +// go build -ldflags "-X main.GitCommit=$GIT_COMMIT" +var GitCommit string + +func main() { + preConfigureLogger() + + rootCmd := cmd.SetupCLI(Version, GitCommit) + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +// preConfigureLogger will set the log level to Trace, so logs works from the +// start. This will eventually be overwritten in cmd/root.go +func preConfigureLogger() { + log.DefaultLogger = log.New() + log.DefaultLogger.SetLevel(logrus.TraceLevel) +} diff --git a/resources/grafana/README.md b/resources/grafana/README.md new file mode 100644 index 000000000..595dbf8c0 --- /dev/null +++ b/resources/grafana/README.md @@ -0,0 +1,119 @@ +# Dashboard Grafana SDPV2 + +This dashboard template shows metrics that are exported in the application using prometheus. Currently, the metrics available in this template are the `HTTP request` and `Database Query` metrics. In the future new metrics will be added including business ones. + +To import the file `dashboard.json` into a grafana instance you need to follow [these steps](https://grafana.com/docs/grafana/latest/dashboards/manage-dashboards/#import-a-dashboard). When importing the dashboard, you need to add the prometheus instance that is scraping the application metrics as a data source in grafana. More information in [this link](https://grafana.com/docs/grafana/latest/datasources/prometheus) + +## HTTP Requests Metrics + +At the moment the HTTP request metrics are being monitored by 7 panels, and this panels can be filtered by: + +- Request endpoint +- HTTP method +- Status returned by request +- Instance that is running the app + +### Request HTTP Rate + +Chart responsible for calculating the number of requests per second performed in a 5-minute interval. Can be filtered by route, method, status, and instance. + +![http_request_rate](./images/http_request_rate.png) + +### Request HTTP Average Time + +Chart responsible for calculating the average time in ms that a request takes to be executed in a 5-minute interval. Can be filtered by route, method, status, and instance. + +![http_request_average_time](./images/http_request_average_time.png) + +### Request HTTP Rate Aggregate By Instance + +Chart responsible for calculating the number of requests per second performed in a 5-minute interval aggregated by the average of all instances. Can be filtered by route, method, and status. + +![http_request_rate_aggregate_by_instance](./images/http_request_rate_aggregate_by_instance.png) + +### Request HTTP Average Time Aggregate By Instance + +Chart responsible for calculating the average time in ms that a request takes to be executed in a 5-minute interval aggregated by the average of all instances. Can be filtered by route, method, and status. + +![http_request_average_time_aggregate_by_instance ](./images/http_request_average_time_aggregate_by_instance.png) + +### HTTP Request Quantiles + +Graph responsible for showing the duration quantiles of http requests calculated by Prometheus. Can be filtered by route, method, status, and instance. + +![http_request_quantile](./images/http_request_quantile.png) + +### Total and Average HTTP Requests + +Panels responsible for showing the aggregate sum of requests performed by the application since it started running. And also the average time these requests take to execute aggregated by the average. Can be filtered by route, method, status, and instance. + +![total_requests_and_average_time](./images/total_requests_and_average_time.png) + +## DB Query Metrics + +At the moment the DB Query metrics are being monitored by 10 panels, and this panels can be filtered by: + +- Query type (Select, Update, Insert, Create) +- Instance that is running the app + +The database metrics are separated into successful and failure queries. + +### Successful DB Query Rate + +Chart responsible for calculating the number of successful queries per second performed in a 5-minute interval. Can be filtered by query_type and instance. + +![successful_query_rate](./images/successful_query_rate.png) + +### Failure DB Query Rate + +Chart responsible for calculating the number of failed queries per second performed in a 5-minute interval. Can be filtered by query_type and instance. + +![failure_query_rate](./images/failure_query_rate.png) + +### Successful DB Query Average Time + +Chart responsible for calculating the average time in ms that a successful query takes to be executed in a 5-minute interval. Can be filtered by query_type and instance. + +![avg_successful_query_duration](./images/avg_successful_query_duration.png) + +### Failure DB Query Average Time + +Chart responsible for calculating the average time in ms that a failed query takes to be executed in a 5-minute interval. Can be filtered by query_type and instance. + +![avg_failure_query_duration](./images/avg_failure_query_duration.png) + +### Successful DB Query Rate Aggregate By Instance + +Chart responsible for calculating the number of successful queries per second performed in a 5-minute interval aggregated by the average of all instances. Can be filtered by query_type. + +![successful_query_rate_aggregate_by_instance](./images/successful_query_rate_aggregate_by_instance.png) + +### Failure DB Query Rate Aggregate By Instance + +Chart responsible for calculating the number of failed queries per second performed in a 5-minute interval aggregated by the average of all instances. Can be filtered by query_type. + +![failure_query_rate_aggregate_by_instance](./images/failure_query_rate_aggregate_by_instance.png) + +### Successful DB Query Average Time Aggregate By Instance + +Chart responsible for calculating the average time in ms that a successful query takes to be executed in a 5-minute interval aggregated by the average of all instances. Can be filtered by query_type. + +![avg_successful_query_duration_aggregate_by_instance](./images/avg_successful_query_duration_aggregate_by_instance.png) + +### Failure DB Query Average Time Aggregate By Instance + +Chart responsible for calculating the average time in ms that a failed query takes to be executed in a 5-minute interval aggregated by the average of all instances. Can be filtered by query_type. + +![avg_failure_query_duration_aggregate_by_instance](./images/avg_failure_query_duration_aggregate_by_instance.png) + +### Successful DB Query Quantiles + +Graph responsible for showing the duration quantiles of successful queries calculated by Prometheus. Can be filtered by query_type and instance. + +![successful_query_quantiles](./images/successful_query_quantiles.png) + +### Failure DB Query Quantiles + +Graph responsible for showing the duration quantiles of failed queries calculated by Prometheus. Can be filtered by query_type and instance. + +![failure_query_quantiles](./images/failure_query_quantiles.png) diff --git a/resources/grafana/dashboard.json b/resources/grafana/dashboard.json new file mode 100644 index 000000000..dd0e91791 --- /dev/null +++ b/resources/grafana/dashboard.json @@ -0,0 +1,1755 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 2, + "links": [], + "liveNow": false, + "panels": [ + { + "collapsed": true, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 2, + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Rate request per second in a 5 min interval. Filtered by route, method, status and instance.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "req/s", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [ + { + "__systemRef": "hideSeriesFrom", + "matcher": { + "id": "byNames", + "options": { + "mode": "exclude", + "names": [ + "GET /health: 200 10.244.0.232:8002" + ], + "prefix": "All except:", + "readOnly": true + } + }, + "properties": [ + { + "id": "custom.hideFrom", + "value": { + "legend": false, + "tooltip": false, + "viz": true + } + } + ] + } + ] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 1 + }, + "id": 4, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "rate(sdp_http_requests_duration_seconds_count{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"}[5m])", + "legendFormat": "{{method}} {{route}}: {{status}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Request Rate [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average request time duration in a 5 min interval. Filtered by route, method, status and instance.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 1 + }, + "id": 6, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * (rate(sdp_http_requests_duration_seconds_sum{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"}[5m]) / rate(sdp_http_requests_duration_seconds_count{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"}[5m]))", + "legendFormat": "{{method}} {{route}}: {{status}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Avg Request Duration [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average rate request per second in a 5 min interval aggregate by instance. Filtered by route, method, status.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "req/s", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 9 + }, + "id": 5, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "avg(rate(sdp_http_requests_duration_seconds_count{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"}[5m])) by (route, method, status)", + "legendFormat": "{{method}} {{route}}: {{status}}", + "range": true, + "refId": "A" + } + ], + "title": "Request Rate Aggregate by Instance [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average request time duration in a 5 min interval aggregate by instances. Filtered by route, method, status.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 9 + }, + "id": 7, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * avg((rate(sdp_http_requests_duration_seconds_sum{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"}[5m]) / rate(sdp_http_requests_duration_seconds_count{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"}[5m]))) by (route, method, status)", + "legendFormat": "{{method}} {{route}}: {{status}}", + "range": true, + "refId": "A" + } + ], + "title": "Avg Request Duration Aggregate by instances [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [ + { + "__systemRef": "hideSeriesFrom", + "matcher": { + "id": "byNames", + "options": { + "mode": "exclude", + "names": [ + "q0.5 GET /health: 200 10.244.0.232:8002", + "q0.5 GET /health: 200 10.244.0.234:8002" + ], + "prefix": "All except:", + "readOnly": true + } + }, + "properties": [ + { + "id": "custom.hideFrom", + "value": { + "legend": false, + "tooltip": false, + "viz": true + } + } + ] + } + ] + }, + "gridPos": { + "h": 12, + "w": 24, + "x": 0, + "y": 17 + }, + "id": 9, + "interval": "1m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * sdp_http_requests_duration_seconds{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"}", + "legendFormat": "q{{quantile}} {{method}} {{route}}: {{status}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Http Request Quantile", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Total number of requests filtered by route, method, status and instance.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 29 + }, + "id": 11, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "9.3.8", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "sum(sdp_http_requests_duration_seconds_count{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"})", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Total Requests", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Total average time duration from a request. Filtered by route, method, status and instance", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 29 + }, + "id": 12, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "9.3.8", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * avg(sdp_http_requests_duration_seconds_sum{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"} / sdp_http_requests_duration_seconds_count{route=~\"$route\", method=~\"$method\", status=~\"$status\", instance=~\"$instance\"})", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Total Avg Request Time Duration [ms]", + "type": "stat" + } + ], + "title": "Http Metrics", + "type": "row" + }, + { + "collapsed": true, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 1 + }, + "id": 14, + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Rate successful query per second in a 5 min interval. Filtered by query_type and instance.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "query/s", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 2 + }, + "id": 15, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "rate(sdp_db_successful_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m])", + "legendFormat": "{{query_type}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Successful Query Rate [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average successful query duration in a 5 min interval. Filtered by query_type and instance.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 2 + }, + "id": 19, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * (rate(sdp_db_successful_queries_duration_sum{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]) / rate(sdp_db_successful_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]))", + "legendFormat": "{{query_type}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Avg Successful Query Duration [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average rate successful query per second in a 5 min interval. Filtered by query_type.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "query/s", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 10 + }, + "id": 17, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "avg(rate(sdp_db_successful_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m])) by (query_type)", + "legendFormat": "{{query_type}}", + "range": true, + "refId": "A" + } + ], + "title": "Successful Query Rate Aggregate By Instances [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average successful query duration in a 5 min interval aggregate by instances. Filtered by query_type.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 10 + }, + "id": 20, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "avg(1000 * (rate(sdp_db_successful_queries_duration_sum{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]) / rate(sdp_db_successful_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]))) by (query_type)", + "legendFormat": "{{query_type}}", + "range": true, + "refId": "A" + } + ], + "title": "Avg Succesful Query Duration Aggregate By Instances[5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 9, + "w": 24, + "x": 0, + "y": 18 + }, + "id": 24, + "interval": "1m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * sdp_db_successful_queries_duration{query_type=~\"$query_type\", instance=~\"$instance\"}", + "legendFormat": "q{{quantile}} {{query_type}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Successful Query Quantiles", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Rate failure query per second in a 5 min interval. Filtered by query_type and instance.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "query/s", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 27 + }, + "id": 16, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "rate(sdp_db_failure_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m])", + "legendFormat": "{{query_type}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Failure Query Rate [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average failure query duration in a 5 min interval. Filtered by query_type and instance.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 27 + }, + "id": 21, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * (rate(sdp_db_failure_queries_duration_sum{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]) / rate(sdp_db_failure_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]))", + "legendFormat": "{{query_type}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Avg Failure Query Duration [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average rate failure query per second in a 5 min interval. Filtered by query_type.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "query/s", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 35 + }, + "id": 18, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "avg(rate(sdp_db_failure_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m])) by (query_type)", + "legendFormat": "{{query_type}}", + "range": true, + "refId": "A" + } + ], + "title": "Failure Query Rate Aggregate By Instances [5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "description": "Average failure query duration in a 5 min interval aggregate by instances. Filtered by query_type.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 3, + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 35 + }, + "id": 22, + "interval": "5m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "avg(1000 * (rate(sdp_db_failure_queries_duration_sum{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]) / rate(sdp_db_failure_queries_duration_count{query_type=~\"$query_type\", instance=~\"$instance\"}[5m]))) by (query_type)", + "legendFormat": "{{query_type}}", + "range": true, + "refId": "A" + } + ], + "title": "Avg Failure Query Duration Aggregate By Instances[5m]", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "milliseconds", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 5, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 7, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 9, + "w": 24, + "x": 0, + "y": 43 + }, + "id": 25, + "interval": "1m", + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "editorMode": "code", + "expr": "1000 * sdp_db_failure_queries_duration{query_type=~\"$query_type\", instance=~\"$instance\"}", + "legendFormat": "q{{quantile}} {{query_type}} {{instance}}", + "range": true, + "refId": "A" + } + ], + "title": "Failure Query Quantiles", + "type": "timeseries" + } + ], + "title": "DB Metrics", + "type": "row" + } + ], + "schemaVersion": 37, + "style": "dark", + "tags": [], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "definition": "query_result(sdp_http_requests_duration_seconds_count)", + "description": "", + "hide": 0, + "includeAll": true, + "multi": true, + "name": "route", + "options": [], + "query": { + "query": "query_result(sdp_http_requests_duration_seconds_count)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "/.*route=\"([^\"]+).*/", + "skipUrlSync": false, + "sort": 0, + "type": "query" + }, + { + "current": { + "selected": true, + "text": [ + "GET" + ], + "value": [ + "GET" + ] + }, + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "definition": "query_result(sdp_http_requests_duration_seconds_count)", + "hide": 0, + "includeAll": true, + "multi": true, + "name": "method", + "options": [], + "query": { + "query": "query_result(sdp_http_requests_duration_seconds_count)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "/.*method=\"([^\"]+).*/", + "skipUrlSync": false, + "sort": 0, + "type": "query" + }, + { + "current": { + "selected": true, + "text": [ + "200" + ], + "value": [ + "200" + ] + }, + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "definition": "query_result(sdp_http_requests_duration_seconds_count)", + "hide": 0, + "includeAll": true, + "multi": true, + "name": "status", + "options": [], + "query": { + "query": "query_result(sdp_http_requests_duration_seconds_count)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "/.*status=\"([^\"]+).*/", + "skipUrlSync": false, + "sort": 0, + "type": "query" + }, + { + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "definition": "query_result(sdp_http_requests_duration_seconds_count)", + "hide": 0, + "includeAll": true, + "label": "instance", + "multi": true, + "name": "instance", + "options": [], + "query": { + "query": "query_result(sdp_http_requests_duration_seconds_count)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "/.*instance=\"([^\"]+).*/", + "skipUrlSync": false, + "sort": 0, + "type": "query" + }, + { + "current": { + "selected": false, + "text": "All", + "value": "$__all" + }, + "datasource": { + "type": "prometheus", + "uid": "NJaP9W-4z" + }, + "definition": "query_result(sdp_db_failure_queries_duration_count or sdp_db_successful_queries_duration_count)", + "hide": 0, + "includeAll": true, + "label": "query_type", + "multi": true, + "name": "query_type", + "options": [], + "query": { + "query": "query_result(sdp_db_failure_queries_duration_count or sdp_db_successful_queries_duration_count)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "regex": "/.*query_type=\"([^\"]+).*/", + "skipUrlSync": false, + "sort": 0, + "type": "query" + } + ] + }, + "time": { + "from": "now-15m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Dashboard SDPV2", + "uid": "XIR0jW-4k", + "version": 40, + "weekStart": "" +} \ No newline at end of file diff --git a/resources/grafana/images/avg_failure_query_duration.png b/resources/grafana/images/avg_failure_query_duration.png new file mode 100644 index 000000000..ef85dbba7 Binary files /dev/null and b/resources/grafana/images/avg_failure_query_duration.png differ diff --git a/resources/grafana/images/avg_failure_query_duration_aggregate_by_instance.png b/resources/grafana/images/avg_failure_query_duration_aggregate_by_instance.png new file mode 100644 index 000000000..cb5c3c4b5 Binary files /dev/null and b/resources/grafana/images/avg_failure_query_duration_aggregate_by_instance.png differ diff --git a/resources/grafana/images/avg_successful_query_duration.png b/resources/grafana/images/avg_successful_query_duration.png new file mode 100644 index 000000000..292690045 Binary files /dev/null and b/resources/grafana/images/avg_successful_query_duration.png differ diff --git a/resources/grafana/images/avg_successful_query_duration_aggregate_by_instance.png b/resources/grafana/images/avg_successful_query_duration_aggregate_by_instance.png new file mode 100644 index 000000000..4a9428328 Binary files /dev/null and b/resources/grafana/images/avg_successful_query_duration_aggregate_by_instance.png differ diff --git a/resources/grafana/images/failure_query_quantiles.png b/resources/grafana/images/failure_query_quantiles.png new file mode 100644 index 000000000..9168808c2 Binary files /dev/null and b/resources/grafana/images/failure_query_quantiles.png differ diff --git a/resources/grafana/images/failure_query_rate.png b/resources/grafana/images/failure_query_rate.png new file mode 100644 index 000000000..36d99bbaa Binary files /dev/null and b/resources/grafana/images/failure_query_rate.png differ diff --git a/resources/grafana/images/failure_query_rate_aggregate_by_instance.png b/resources/grafana/images/failure_query_rate_aggregate_by_instance.png new file mode 100644 index 000000000..87fe8833f Binary files /dev/null and b/resources/grafana/images/failure_query_rate_aggregate_by_instance.png differ diff --git a/resources/grafana/images/http_request_average_time.png b/resources/grafana/images/http_request_average_time.png new file mode 100644 index 000000000..53a475396 Binary files /dev/null and b/resources/grafana/images/http_request_average_time.png differ diff --git a/resources/grafana/images/http_request_average_time_aggregate_by_instance.png b/resources/grafana/images/http_request_average_time_aggregate_by_instance.png new file mode 100644 index 000000000..14157efc5 Binary files /dev/null and b/resources/grafana/images/http_request_average_time_aggregate_by_instance.png differ diff --git a/resources/grafana/images/http_request_quantile.png b/resources/grafana/images/http_request_quantile.png new file mode 100644 index 000000000..e7b7650c8 Binary files /dev/null and b/resources/grafana/images/http_request_quantile.png differ diff --git a/resources/grafana/images/http_request_rate.png b/resources/grafana/images/http_request_rate.png new file mode 100644 index 000000000..05db836c1 Binary files /dev/null and b/resources/grafana/images/http_request_rate.png differ diff --git a/resources/grafana/images/http_request_rate_aggregate_by_instance.png b/resources/grafana/images/http_request_rate_aggregate_by_instance.png new file mode 100644 index 000000000..1a9a147a1 Binary files /dev/null and b/resources/grafana/images/http_request_rate_aggregate_by_instance.png differ diff --git a/resources/grafana/images/successful_query_quantiles.png b/resources/grafana/images/successful_query_quantiles.png new file mode 100644 index 000000000..7400e9b09 Binary files /dev/null and b/resources/grafana/images/successful_query_quantiles.png differ diff --git a/resources/grafana/images/successful_query_rate.png b/resources/grafana/images/successful_query_rate.png new file mode 100644 index 000000000..9450afba3 Binary files /dev/null and b/resources/grafana/images/successful_query_rate.png differ diff --git a/resources/grafana/images/successful_query_rate_aggregate_by_instance.png b/resources/grafana/images/successful_query_rate_aggregate_by_instance.png new file mode 100644 index 000000000..205488d44 Binary files /dev/null and b/resources/grafana/images/successful_query_rate_aggregate_by_instance.png differ diff --git a/resources/grafana/images/total_requests_and_average_time.png b/resources/grafana/images/total_requests_and_average_time.png new file mode 100644 index 000000000..eba8d263f Binary files /dev/null and b/resources/grafana/images/total_requests_and_average_time.png differ diff --git a/resources/grafana/transaction_submission_service_dashboard.json b/resources/grafana/transaction_submission_service_dashboard.json new file mode 100644 index 000000000..ae95ace85 --- /dev/null +++ b/resources/grafana/transaction_submission_service_dashboard.json @@ -0,0 +1,1123 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 1, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 6, + "w": 5, + "x": 0, + "y": 0 + }, + "id": 14, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "expr": "tss_tx_processing_processed_count{result=\"success\"}", + "refId": "A" + } + ], + "title": "Transactions Processed (Success)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 6, + "w": 5, + "x": 5, + "y": 0 + }, + "id": 19, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "expr": "tss_tx_processing_processed_count{result=\"error\"}", + "refId": "A" + } + ], + "title": "Transactions Processed (Errors)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 6, + "w": 5, + "x": 10, + "y": 0 + }, + "id": 2, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "expr": "tss_log_error_total", + "refId": "A" + } + ], + "title": "Log (Errors)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 6, + "w": 5, + "x": 15, + "y": 0 + }, + "id": 18, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "expr": "tss_log_error_total", + "refId": "A" + } + ], + "title": "Log (Errors)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 6, + "w": 4, + "x": 20, + "y": 0 + }, + "id": 4, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "expr": "tss_log_warn_total", + "refId": "A" + } + ], + "title": "(TODO) - Horizon (Errors)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 12, + "x": 0, + "y": 6 + }, + "id": 12, + "options": { + "legend": { + "calcs": [ + "lastNotNull" + ], + "displayMode": "table", + "placement": "bottom" + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": " sum without (instance)(rate(tss_tx_processing_queued_to_completed_latency_seconds_sum{result=\"success\"}[$__rate_interval\n]))\n/\n sum without (instance)(rate(tss_tx_processing_queued_to_completed_latency_seconds_count{result=\"success\"}[$__rate_interval\n]))", + "hide": false, + "legendFormat": "latency (since createAt) - {{result}}", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "exemplar": false, + "expr": " sum without (instance)(rate(tss_tx_processing_started_to_completed_latency_seconds_sum{result=\"success\"}[$__rate_interval\n]))\n/\n sum without (instance)(rate(tss_tx_processing_started_to_completed_latency_seconds_count{result=\"success\"}[$__rate_interval\n]))", + "format": "time_series", + "hide": false, + "instant": false, + "interval": "", + "legendFormat": "latency (since startedAt) - {{result}}", + "range": true, + "refId": "B" + } + ], + "title": "Transaction Latency", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "none" + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 12, + "x": 12, + "y": 6 + }, + "id": 22, + "options": { + "legend": { + "calcs": [ + "lastNotNull" + ], + "displayMode": "table", + "placement": "bottom" + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum(rate(tss_tx_processing_retry_count_bucket[$__rate_interval])) by (le))", + "legendFormat": "99th", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum(rate(tss_tx_processing_retry_count_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "95th", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum(rate(tss_tx_processing_retry_count_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "90th", + "range": true, + "refId": "C" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum(rate(tss_tx_processing_retry_count_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "50th", + "range": true, + "refId": "D" + } + ], + "title": "Transactions Retried (Attempts) Quantiles", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 12, + "x": 0, + "y": 16 + }, + "id": 21, + "options": { + "legend": { + "calcs": [ + "lastNotNull" + ], + "displayMode": "table", + "placement": "bottom" + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum(rate(tss_tx_processing_queued_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "legendFormat": "99th", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum(rate(tss_tx_processing_queued_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "95th", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum(rate(tss_tx_processing_queued_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "90th", + "range": true, + "refId": "C" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum(rate(tss_tx_processing_queued_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "50th", + "range": true, + "refId": "E" + } + ], + "title": "Transaction Latency (Queued to Completed) Quantiles", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 12, + "x": 12, + "y": 16 + }, + "id": 23, + "options": { + "legend": { + "calcs": [ + "lastNotNull" + ], + "displayMode": "table", + "placement": "bottom" + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum(rate(tss_tx_processing_started_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "legendFormat": "99th", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum(rate(tss_tx_processing_started_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "95th", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum(rate(tss_tx_processing_started_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "90th", + "range": true, + "refId": "C" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum(rate(tss_tx_processing_started_to_completed_latency_seconds_bucket[$__rate_interval])) by (le))", + "hide": false, + "legendFormat": "50th", + "range": true, + "refId": "E" + } + ], + "title": "Transaction Latency (Started to Completed) Quantiles", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + } + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 12, + "w": 12, + "x": 0, + "y": 26 + }, + "id": 16, + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom" + }, + "pieType": "pie", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "title": "(TODO) - Transaction Submission Error Distribution", + "type": "piechart" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + } + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 12, + "w": 12, + "x": 12, + "y": 26 + }, + "id": 17, + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom" + }, + "pieType": "pie", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "title": "(TODO) - Horizon Error Distribution", + "type": "piechart" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "decbytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 12, + "x": 0, + "y": 38 + }, + "id": 6, + "options": { + "legend": { + "calcs": [ + "lastNotNull" + ], + "displayMode": "table", + "placement": "bottom" + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "expr": "go_memstats_alloc_bytes", + "refId": "A" + } + ], + "title": "Memory (MB)", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 12, + "x": 12, + "y": 38 + }, + "id": 8, + "options": { + "legend": { + "calcs": [ + "lastNotNull" + ], + "displayMode": "table", + "placement": "bottom" + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "8.5.5", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "KIYiIaUVz" + }, + "expr": "go_goroutines", + "refId": "A" + } + ], + "title": "Goroutine Count", + "type": "timeseries" + } + ], + "refresh": "5s", + "schemaVersion": 36, + "style": "dark", + "tags": [], + "templating": { + "list": [] + }, + "time": { + "from": "now-1h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Transaction Submission Service Metrics", + "uid": "lSoCpU84z", + "version": 4, + "weekStart": "" +} \ No newline at end of file diff --git a/scripts/exclude_from_coverage.sh b/scripts/exclude_from_coverage.sh new file mode 100755 index 000000000..d21d6098b --- /dev/null +++ b/scripts/exclude_from_coverage.sh @@ -0,0 +1,7 @@ +#!/bin/sh +while IFS= read -r p || [ -n "$p" ]; do + exp=".*${p}.*" + sed -i "/${exp}/d" ./c.out +done << EOF # list of terms and files we want to exclude +mocks +EOF diff --git a/stellar-auth/README.md b/stellar-auth/README.md new file mode 100644 index 000000000..218c046a7 --- /dev/null +++ b/stellar-auth/README.md @@ -0,0 +1,256 @@ +# Stellar Auth + +Stellar Auth is a package that provides authentication functionality for Stellar applications. It simplifies the process of managing user authentication. + +## Table of Contents + +- [CLI](#cli) + - [add-user](#add-user) + - [roles](#roles) +- [Usage](#usage) + +## CLI + +The Stellar Auth provides a CLI that helps adding new users and applying the database migrations in order to create all necessary tables. + +```sh +Stellar Auth handles JWT management. + +Usage: + stellarauth [flags] + stellarauth [command] + +Available Commands: + add-user Add user to the system + completion Generate the autocompletion script for the specified shell + help Help about any command + migrate Apply Stellar Auth database migrations + +Flags: + --database-url string Postgres DB URL (DATABASE_URL) (default "postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable") + -h, --help help for stellarauth + --log-level string The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC". (LOG_LEVEL) (default "TRACE") + +Use "stellarauth [command] --help" for more information about a command. +``` + +### add-user + +To add a new user using the CLI you can use the `add-user` subcommand. + +```sh +$ stellarauth add-user --help + +Usage: + stellarauth add-user [--owner] [--roles] [--password] [flags] + +Flags: + -h, --help help for add-user + --owner Set the user as Owner (superuser). Defaults to "false". (OWNER) + --password Sets the user password, it should be at least 8 characters long, if omitted, the command will generate a random one. (PASSWORD) + +Global Flags: + --database-url string Postgres DB URL (DATABASE_URL) (default "postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable") + --log-level string The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC". (LOG_LEVEL) (default "TRACE") +``` + +When creating a new user you can set the password. + +```sh +$ export DATABASE_URL=postgres://... # Or you can specify in the command --database-url postgres://.. +$ stellarauth migrate up # Creating the necessary tables +$ stellarauth add-user mary.jane@stellar.org Mary Jane --password + +INFO[2023-07-31T17:05:46.292-03:00] Version: 0.2.0 pid=22464 +INFO[2023-07-31T17:05:46.292-03:00] GitCommit: pid=22464 +Password: +INFO[2023-07-31T17:05:55.159-03:00] user inserted: mary.jane@stellar.org pid=22464 +``` + +### roles + +You can add role management by passing the available roles to the `AddUserCmd`. After this the flag `--roles` will show up in the `add-user` subcommand. + +```go +// pkg/cli/root.go + +func SetupCLI(version, gitCommit string) *cobra.Command { + // ... + + cmd.AddCommand(AddUserCmd("", NewDefaultPasswordPrompt(), []string{"approver", "editor", "owner"})) + + return cmd +} + +``` + +```sh +$ stellarauth add-user --help + +Usage: + stellarauth add-user [--owner] [--roles] [--password] [flags] + +Flags: + -h, --help help for add-user + --owner Set the user as Owner (superuser). Defaults to "false". (OWNER) + --password Sets the user password, it should be at least 8 characters long, if omitted, the command will generate a random one. (PASSWORD) + --roles string Set the user roles. It should be comma separated. Example: role1, role2. Available roles: [approver, editor, owner]. (ROLES) + +Global Flags: + --database-url string Postgres DB URL (DATABASE_URL) (default "postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable") + --log-level string The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC". (LOG_LEVEL) (default "TRACE") +``` + +```sh +$ stellarauth add-user mary.jane@stellar.org Mary Jane --roles approver,editor --password + +INFO[2023-07-31T17:05:46.292-03:00] Version: 0.2.0 pid=22464 +INFO[2023-07-31T17:05:46.292-03:00] GitCommit: pid=22464 +Password: +INFO[2023-07-31T17:05:55.159-03:00] user inserted: mary.jane@stellar.org pid=22464 +``` + +## Usage + +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "time" + + authdb "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +var AuthManager auth.AuthManager + +type LoginRequest struct { + Email string `json:"email"` + Password string `json:"password"` +} + +func main() { + mux := http.NewServeMux() + + databaseURL := os.Getenv("DATABASE_URL") + dbConnectionPool, err := authdb.OpenDBConnectionPool(databaseURL) + if err != nil { + log.Fatal(err) + } + + // Instantiating AuthManager using the default options + AuthManager = auth.NewAuthManager( + auth.WithDefaultAuthenticatorOption(dbConnectionPool, auth.NewDefaultPasswordEncrypter(), time.Hour*1), + auth.WithDefaultJWTManagerOption(os.Getenv("EC256_PUBLIC_KEY"), os.Getenv("DATABASE_URL")), + auth.WithDefaultRoleManagerOption(dbConnectionPool, "owner"), + ) + + mux.HandleFunc("/login", login) + mux.HandleFunc("/refresh-token", refreshToken) + mux.Handle("/authenticated", AuthenticatedMiddleware(http.HandlerFunc(myAuthenticatedHandler))) + mux.Handle("/role-required", AuthenticatedMiddleware( + RoleMiddleware([]string{"myRole1", "myRole2"})(http.HandlerFunc(myRoleRequiredHandler)), + )) + + http.ListenAndServe(":8000", mux) +} + +func myAuthenticatedHandler(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`Ok`)) +} + +func myRoleRequiredHandler(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`Ok`)) +} + +func login(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var reqBody LoginRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "invalid request body"}`)) + } + + token, err := AuthManager.Authenticate(ctx, reqBody.Email, reqBody.Password) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Write([]byte(fmt.Sprintf(`{"token": %q}`, token))) +} + +func refreshToken(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + token := r.Header.Get("Authorization") + + token, err := AuthManager.RefreshToken(ctx, token) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Write([]byte(fmt.Sprintf(`{"token": %q}`, token))) +} + +func AuthenticatedMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + token := r.Header.Get("Authorization") + + // Does the header validation... + + isValid, err := AuthManager.ValidateToken(ctx, token) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "not authorized"}`)) + return + } + + if !isValid { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "not authorized"}`)) + return + } + + // Additionally you can add the token to the request context + ctx = context.WithValue(ctx, "tokenKey", token) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) +} + +func RoleMiddleware(requiredRoles []string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + token := r.Header.Get("Authorization") + + // Does the header validation... + + hasAnyRoles, err := AuthManager.AnyRolesInTokenUser(ctx, token, requiredRoles) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "not authorized"}`)) + return + } + + if !hasAnyRoles { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "not authorized"}`)) + return + } + + next.ServeHTTP(w, r) + }) + } +} +``` diff --git a/stellar-auth/cmd/stellarauth/main.go b/stellar-auth/cmd/stellarauth/main.go new file mode 100644 index 000000000..67fb923d5 --- /dev/null +++ b/stellar-auth/cmd/stellarauth/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "github.com/sirupsen/logrus" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/cli" +) + +// Version is the official version of this application. +const Version = "0.2.0" + +// GitCommit is populated at build time by +// go build -ldflags "-X main.GitCommit=$GIT_COMMIT" +var GitCommit string + +func main() { + log.DefaultLogger = log.New() + log.DefaultLogger.SetLevel(logrus.TraceLevel) + + cmd := cli.SetupCLI(Version, GitCommit) + if err := cmd.Execute(); err != nil { + log.Fatalf("error executing: %s", err.Error()) + } +} diff --git a/stellar-auth/internal/db/db.go b/stellar-auth/internal/db/db.go new file mode 100644 index 000000000..adce8beb1 --- /dev/null +++ b/stellar-auth/internal/db/db.go @@ -0,0 +1,151 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/jmoiron/sqlx" + "github.com/stellar/go/support/log" +) + +const ( + MaxDBConnIdleTime = 10 * time.Second + MaxOpenDBConns = 30 +) + +// DBConnectionPoolFromSqlDB returns a new DBConnectionPool wrapper for a PRE-EXISTING *sql.DB. The driverName of the +// original database is required for named query support. ATTENTION: this will not start a new connection pool, just +// create a wrap aroung the pre-existing connection pool. +func DBConnectionPoolFromSqlDB(sqlDB *sql.DB, driverName string) DBConnectionPool { + return &DBConnectionPoolImplementation{DB: sqlx.NewDb(sqlDB, driverName)} +} + +// DBConnectionPool is an interface that wraps the sqlx.DB structs methods and includes the RunInTransaction helper. +type DBConnectionPool interface { + SQLExecuter + BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) + Close() error + Ping() error + SqlDB() *sql.DB + SqlxDB() *sqlx.DB +} + +// DBConnectionPoolImplementation is a wrapper around sqlx.DB that implements DBConnectionPool. +type DBConnectionPoolImplementation struct { + *sqlx.DB +} + +func (db *DBConnectionPoolImplementation) BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) { + return db.DB.BeginTxx(ctx, opts) +} + +func (db *DBConnectionPoolImplementation) SqlDB() *sql.DB { + return db.DB.DB +} + +func (db *DBConnectionPoolImplementation) SqlxDB() *sqlx.DB { + return db.DB +} + +// RunInTransactionWithResult runs the given atomic function in an atomic database transaction and returns a result and +// an error. Boilerplate code for database transactions. +func RunInTransactionWithResult[T any](ctx context.Context, dbConnectionPool DBConnectionPool, opts *sql.TxOptions, atomicFunction func(dbTx DBTransaction) (T, error)) (result T, err error) { + dbTx, err := dbConnectionPool.BeginTxx(ctx, opts) + if err != nil { + return *new(T), fmt.Errorf("creating db transaction for RunInTransactionWithResult: %w", err) + } + + defer func() { + DBTxRollback(ctx, dbTx, err, "rolling back transaction due to error") + }() + + result, err = atomicFunction(dbTx) + if err != nil { + return *new(T), fmt.Errorf("running atomic function in RunInTransactionWithResult: %w", err) + } + + err = dbTx.Commit() + if err != nil { + return *new(T), fmt.Errorf("committing transaction in RunInTransactionWithResult: %w", err) + } + + return result, nil +} + +// RunInTransaction runs the given atomic function in an atomic database transaction and returns an error. Boilerplate +// code for database transactions. +func RunInTransaction(ctx context.Context, dbConnectionPool DBConnectionPool, opts *sql.TxOptions, atomicFunction func(dbTx DBTransaction) error) error { + // wrap the atomic function with a function that returns nil and an error so we can call RunInTransactionWithResult + wrappedFunction := func(dbTx DBTransaction) (interface{}, error) { + return nil, atomicFunction(dbTx) + } + + _, err := RunInTransactionWithResult(ctx, dbConnectionPool, opts, wrappedFunction) + return err +} + +// make sure *DBConnectionPoolImplementation implements DBConnectionPool: +var _ DBConnectionPool = (*DBConnectionPoolImplementation)(nil) + +// DBTransaction is an interface that wraps the sqlx.Tx structs methods. +type DBTransaction interface { + SQLExecuter + Rollback() error + Commit() error +} + +// make sure *sqlx.Tx implements DBTransaction: +var _ DBTransaction = (*sqlx.Tx)(nil) + +// SQLExecuter is an interface that wraps the *sqlx.DB and *sqlx.Tx structs methods. +type SQLExecuter interface { + DriverName() string + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + sqlx.PreparerContext + sqlx.QueryerContext + Rebind(query string) string + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error +} + +// make sure *sqlx.DB implements SQLExecuter: +var _ SQLExecuter = (*sqlx.DB)(nil) + +// make sure DBConnectionPool implements SQLExecuter: +var _ SQLExecuter = (DBConnectionPool)(nil) + +// make sure *sqlx.Tx implements SQLExecuter: +var _ SQLExecuter = (*sqlx.Tx)(nil) + +// make sure DBTransaction implements SQLExecuter: +var _ SQLExecuter = (DBTransaction)(nil) + +// DBTxRollback rolls back the transaction if there is an error. +func DBTxRollback(ctx context.Context, dbTx DBTransaction, err error, logMessage string) { + if err != nil { + log.Ctx(ctx).Errorf("%s: %s", logMessage, err.Error()) + errRollBack := dbTx.Rollback() + if errRollBack != nil { + log.Ctx(ctx).Errorf("error in database transaction rollback: %s", errRollBack.Error()) + } + } +} + +// OpenDBConnectionPool opens a new database connection pool. It returns an error if it can't connect to the database. +func OpenDBConnectionPool(dataSourceName string) (DBConnectionPool, error) { + sqlxDB, err := sqlx.Open("postgres", dataSourceName) + if err != nil { + return nil, fmt.Errorf("error creating app DB connection pool: %w", err) + } + sqlxDB.SetConnMaxIdleTime(MaxDBConnIdleTime) + sqlxDB.SetMaxOpenConns(MaxOpenDBConns) + + err = sqlxDB.Ping() + if err != nil { + return nil, fmt.Errorf("error pinging app DB connection pool: %w", err) + } + + return &DBConnectionPoolImplementation{DB: sqlxDB}, nil +} diff --git a/stellar-auth/internal/db/db_test.go b/stellar-auth/internal/db/db_test.go new file mode 100644 index 000000000..a67723e55 --- /dev/null +++ b/stellar-auth/internal/db/db_test.go @@ -0,0 +1,23 @@ +package db + +import ( + "testing" + + "github.com/stellar/go/support/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpen_OpenDBConnectionPool(t *testing.T) { + db := dbtest.Postgres(t) + defer db.Close() + + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + assert.Equal(t, "postgres", dbConnectionPool.DriverName()) + + err = dbConnectionPool.Ping() + require.NoError(t, err) +} diff --git a/stellar-auth/internal/db/dbtest/dbtest.go b/stellar-auth/internal/db/dbtest/dbtest.go new file mode 100644 index 000000000..42d2232b6 --- /dev/null +++ b/stellar-auth/internal/db/dbtest/dbtest.go @@ -0,0 +1,31 @@ +package dbtest + +import ( + "net/http" + "testing" + + migrate "github.com/rubenv/sql-migrate" + "github.com/stellar/go/support/db/dbtest" + "github.com/stellar/go/support/db/schema" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/migrations" +) + +func OpenWithoutMigrations(t *testing.T) *dbtest.DB { + db := dbtest.Postgres(t) + return db +} + +func Open(t *testing.T) *dbtest.DB { + db := OpenWithoutMigrations(t) + + conn := db.Open() + defer conn.Close() + + migrateDirection := schema.MigrateUp + m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrations.FS)} + _, err := schema.Migrate(conn.DB, m, migrateDirection, 0) + if err != nil { + t.Fatal(err) + } + return db +} diff --git a/stellar-auth/internal/db/dbtest/dbtest_test.go b/stellar-auth/internal/db/dbtest/dbtest_test.go new file mode 100644 index 000000000..6ade6b7fa --- /dev/null +++ b/stellar-auth/internal/db/dbtest/dbtest_test.go @@ -0,0 +1,21 @@ +package dbtest + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpen(t *testing.T) { + db := Open(t) + defer db.Close() + + session := db.Open() + defer session.Close() + + count := 0 + err := session.Get(&count, `SELECT COUNT(*) FROM gorp_migrations`) + require.NoError(t, err) + assert.Greater(t, count, 0) +} diff --git a/stellar-auth/internal/db/migrate.go b/stellar-auth/internal/db/migrate.go new file mode 100644 index 000000000..97a267e3f --- /dev/null +++ b/stellar-auth/internal/db/migrate.go @@ -0,0 +1,27 @@ +package db + +import ( + "fmt" + "net/http" + + migrate "github.com/rubenv/sql-migrate" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/migrations" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/utils" +) + +const StellarAuthMigrationsTableName = "auth_migrations" + +func Migrate(dbURL string, dir migrate.MigrationDirection, count int) (int, error) { + dbConnectionPool, err := OpenDBConnectionPool(dbURL) + if err != nil { + return 0, fmt.Errorf("database URL '%s': %w", utils.TruncateString(dbURL, len(dbURL)/4), err) + } + defer dbConnectionPool.Close() + + ms := migrate.MigrationSet{ + TableName: StellarAuthMigrationsTableName, + } + + m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrations.FS)} + return ms.ExecMax(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName(), m, dir, count) +} diff --git a/stellar-auth/internal/db/migrate_test.go b/stellar-auth/internal/db/migrate_test.go new file mode 100644 index 000000000..3c4775985 --- /dev/null +++ b/stellar-auth/internal/db/migrate_test.go @@ -0,0 +1,93 @@ +package db + +import ( + "context" + "fmt" + "io/fs" + "testing" + + migrate "github.com/rubenv/sql-migrate" + "github.com/stellar/stellar-disbursement-platform-backend/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/migrations" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMigrate_upApplyOne(t *testing.T) { + db := dbtest.OpenWithoutMigrations(t) + defer db.Close() + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + n, err := Migrate(db.DSN, migrate.Up, 1) + require.NoError(t, err) + assert.Equal(t, 1, n) + + ids := []string{} + err = dbConnectionPool.SelectContext(ctx, &ids, fmt.Sprintf("SELECT id FROM %s", StellarAuthMigrationsTableName)) + require.NoError(t, err) + wantIDs := []string{"2023-02-09.0.add-users-table.sql"} + assert.Equal(t, wantIDs, ids) +} + +func TestMigrate_downApplyOne(t *testing.T) { + db := dbtest.OpenWithoutMigrations(t) + defer db.Close() + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + n, err := Migrate(db.DSN, migrate.Up, 2) + require.NoError(t, err) + require.Equal(t, 2, n) + + n, err = Migrate(db.DSN, migrate.Down, 1) + require.NoError(t, err) + require.Equal(t, 1, n) + + ids := []string{} + err = dbConnectionPool.SelectContext(ctx, &ids, fmt.Sprintf("SELECT id FROM %s", StellarAuthMigrationsTableName)) + require.NoError(t, err) + wantIDs := []string{"2023-02-09.0.add-users-table.sql"} + assert.Equal(t, wantIDs, ids) +} + +func TestMigrate_upAndDownAllTheWayTwice(t *testing.T) { + db := dbtest.OpenWithoutMigrations(t) + defer db.Close() + dbConnectionPool, err := OpenDBConnectionPool(db.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + // Get number of files in the migrations directory: + var count int + err = fs.WalkDir(migrations.FS, ".", func(path string, d fs.DirEntry, err error) error { + require.NoError(t, err) + if !d.IsDir() { + count++ + } + return nil + }) + require.NoError(t, err) + + n, err := Migrate(db.DSN, migrate.Up, count) + require.NoError(t, err) + require.Equal(t, count, n) + + n, err = Migrate(db.DSN, migrate.Down, count) + require.NoError(t, err) + require.Equal(t, count, n) + + n, err = Migrate(db.DSN, migrate.Up, count) + require.NoError(t, err) + require.Equal(t, count, n) + + n, err = Migrate(db.DSN, migrate.Down, count) + require.NoError(t, err) + require.Equal(t, count, n) +} diff --git a/stellar-auth/internal/db/migrations/2023-02-09.0.add-users-table.sql b/stellar-auth/internal/db/migrations/2023-02-09.0.add-users-table.sql new file mode 100644 index 000000000..981f420da --- /dev/null +++ b/stellar-auth/internal/db/migrations/2023-02-09.0.add-users-table.sql @@ -0,0 +1,21 @@ +-- +migrate Up + +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; + +CREATE TABLE + public.auth_users ( + id VARCHAR(36) PRIMARY KEY DEFAULT uuid_generate_v4(), + username text NOT NULL, + encrypted_password text NOT NULL, + email text NOT NULL, + is_owner boolean NOT NULL DEFAULT false, + created_at TIMESTAMP + WITH + TIME ZONE NOT NULL DEFAULT NOW(), + UNIQUE (username), + UNIQUE (email) + ); + +-- +migrate Down + +DROP TABLE public.auth_users; diff --git a/stellar-auth/internal/db/migrations/2023-03-07.0.add-password-reset-table.sql b/stellar-auth/internal/db/migrations/2023-03-07.0.add-password-reset-table.sql new file mode 100644 index 000000000..6e3b8df35 --- /dev/null +++ b/stellar-auth/internal/db/migrations/2023-03-07.0.add-password-reset-table.sql @@ -0,0 +1,44 @@ +-- +migrate Up + +CREATE TABLE + public.auth_user_password_reset ( + token text NOT NULL UNIQUE, + auth_user_id VARCHAR(36) NOT NULL, + is_valid boolean NOT NULL DEFAULT true, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_password_reset_auth_user_id + FOREIGN KEY (auth_user_id) + REFERENCES auth_users(id) + ); + +CREATE UNIQUE INDEX unique_user_valid_token ON auth_user_password_reset(auth_user_id, is_valid) WHERE (is_valid IS TRUE); + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION auth_user_password_reset_before_insert() +RETURNS TRIGGER AS $auth_user_password_reset_before_insert$ +BEGIN + UPDATE + auth_user_password_reset + SET + is_valid = false + WHERE + auth_user_id = NEW.auth_user_id; + + RETURN NEW; +END; +$auth_user_password_reset_before_insert$ LANGUAGE plpgsql; + + +CREATE TRIGGER auth_user_password_reset_before_insert_trigger +BEFORE INSERT +ON auth_user_password_reset +FOR EACH ROW +EXECUTE PROCEDURE auth_user_password_reset_before_insert(); +-- +migrate StatementEnd + +-- +migrate Down + +DROP TRIGGER auth_user_password_reset_before_insert_trigger ON auth_user_password_reset; +DROP FUNCTION IF EXISTS auth_user_password_reset_before_insert; +DROP INDEX IF EXISTS unique_user_valid_token; +DROP TABLE public.auth_user_password_reset; diff --git a/stellar-auth/internal/db/migrations/2023-03-10.0.alter-users-table-add-roles-column.sql b/stellar-auth/internal/db/migrations/2023-03-10.0.alter-users-table-add-roles-column.sql new file mode 100644 index 000000000..ccd561e56 --- /dev/null +++ b/stellar-auth/internal/db/migrations/2023-03-10.0.alter-users-table-add-roles-column.sql @@ -0,0 +1,7 @@ +-- +migrate Up + +ALTER TABLE auth_users ADD COLUMN roles text[]; + +-- +migrate Down + +ALTER TABLE auth_users DROP COLUMN roles; diff --git a/stellar-auth/internal/db/migrations/2023-03-22.0.alter-users-table-add-is_active-column.sql b/stellar-auth/internal/db/migrations/2023-03-22.0.alter-users-table-add-is_active-column.sql new file mode 100644 index 000000000..4c3697690 --- /dev/null +++ b/stellar-auth/internal/db/migrations/2023-03-22.0.alter-users-table-add-is_active-column.sql @@ -0,0 +1,7 @@ +-- +migrate Up + +ALTER TABLE public.auth_users ADD COLUMN is_active boolean DEFAULT true; + +-- +migrate Down + +ALTER TABLE public.auth_users DROP COLUMN is_active; diff --git a/stellar-auth/internal/db/migrations/2023-03-28.0.alter-users-table-add-new-columns-and-drop-username-column.sql b/stellar-auth/internal/db/migrations/2023-03-28.0.alter-users-table-add-new-columns-and-drop-username-column.sql new file mode 100644 index 000000000..c4e661740 --- /dev/null +++ b/stellar-auth/internal/db/migrations/2023-03-28.0.alter-users-table-add-new-columns-and-drop-username-column.sql @@ -0,0 +1,16 @@ +-- +migrate Up + +ALTER TABLE public.auth_users + ADD COLUMN first_name VARCHAR(128) NOT NULL DEFAULT '', + ADD COLUMN last_name VARCHAR(128) NOT NULL DEFAULT ''; + +ALTER TABLE public.auth_users DROP COLUMN username; + +-- +migrate Down + +ALTER TABLE public.auth_users + DROP COLUMN first_name, + DROP COLUMN last_name; + +ALTER TABLE public.auth_users + ADD COLUMN username VARCHAR(128) UNIQUE; diff --git a/stellar-auth/internal/db/migrations/2023-07-20.0-create-auth_user_mfa_codes_table.sql b/stellar-auth/internal/db/migrations/2023-07-20.0-create-auth_user_mfa_codes_table.sql new file mode 100644 index 000000000..70becf72d --- /dev/null +++ b/stellar-auth/internal/db/migrations/2023-07-20.0-create-auth_user_mfa_codes_table.sql @@ -0,0 +1,36 @@ +-- +migrate Up +CREATE TABLE auth_user_mfa_codes +( + device_id TEXT NOT NULL, + auth_user_id VARCHAR(36) NOT NULL + CONSTRAINT fk_mfa_codes_auth_user_id REFERENCES auth_users, + code VARCHAR(8), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + device_expires_at TIMESTAMP WITH TIME ZONE, + code_expires_at TIMESTAMP WITH TIME ZONE, + CONSTRAINT auth_user_mfa_codes_pkey PRIMARY KEY (device_id, auth_user_id), + UNIQUE (device_id, code) +); + +-- +migrate StatementBegin +CREATE OR REPLACE FUNCTION auth_user_mfa_codes_before_update() + RETURNS TRIGGER AS $auth_user_mfa_codes_before_update$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$auth_user_mfa_codes_before_update$ LANGUAGE plpgsql; + + +CREATE TRIGGER auth_user_mfa_codes_before_update_trigger + BEFORE UPDATE + ON auth_user_mfa_codes + FOR EACH ROW +EXECUTE PROCEDURE auth_user_mfa_codes_before_update(); +-- +migrate StatementEnd + +-- +migrate Down +DROP TRIGGER auth_user_mfa_codes_before_update_trigger ON auth_user_mfa_codes; +DROP FUNCTION auth_user_mfa_codes_before_update(); +DROP TABLE auth_user_mfa_codes; \ No newline at end of file diff --git a/stellar-auth/internal/db/migrations/main.go b/stellar-auth/internal/db/migrations/main.go new file mode 100644 index 000000000..91cca1c33 --- /dev/null +++ b/stellar-auth/internal/db/migrations/main.go @@ -0,0 +1,6 @@ +package migrations + +import "embed" + +//go:embed *.sql +var FS embed.FS diff --git a/stellar-auth/pkg/auth/auth.go b/stellar-auth/pkg/auth/auth.go new file mode 100644 index 000000000..0536b5553 --- /dev/null +++ b/stellar-auth/pkg/auth/auth.go @@ -0,0 +1,353 @@ +package auth + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" +) + +type AuthManager interface { + Authenticate(ctx context.Context, email, pass string) (string, error) + RefreshToken(ctx context.Context, tokenString string) (string, error) + ValidateToken(ctx context.Context, tokenString string) (bool, error) + AllRolesInTokenUser(ctx context.Context, tokenString string, roleNames []string) (bool, error) + AnyRolesInTokenUser(ctx context.Context, tokenString string, roleNames []string) (bool, error) + CreateUser(ctx context.Context, user *User, password string) (*User, error) + UpdateUser(ctx context.Context, tokenString, firstName, lastName, email, password string) error + ForgotPassword(ctx context.Context, email string) (string, error) + ResetPassword(ctx context.Context, tokenString, password string) error + GetUser(ctx context.Context, tokenString string) (*User, error) + GetAllUsers(ctx context.Context, tokenString string) ([]User, error) + UpdateUserRoles(ctx context.Context, tokenString, userID string, roles []string) error + DeactivateUser(ctx context.Context, tokenString, userID string) error + ActivateUser(ctx context.Context, tokenString, userID string) error + ExpirationTimeInMinutes() time.Duration + MFADeviceRemembered(ctx context.Context, deviceID, userID string) (bool, error) + GetMFACode(ctx context.Context, deviceID, userID string) (string, error) + AuthenticateMFA(ctx context.Context, deviceID, code string, rememberMe bool) (string, error) +} + +// DBConnectionPoolFromSqlDB returns a new DBConnectionPool wrapper for a PRE-EXISTING *sql.DB. The driverName of the +// original database is required for named query support. ATTENTION: this will not start a new connection pool, just +// create a wrap aroung the pre-existing connection pool. +func DBConnectionPoolFromSqlDB(sqlDB *sql.DB, driverName string) db.DBConnectionPool { + return db.DBConnectionPoolFromSqlDB(sqlDB, driverName) +} + +var ( + ErrInvalidToken = errors.New("invalid token") + ErrInvalidMFACode = errors.New("invalid MFA code") +) + +func (am *defaultAuthManager) Authenticate(ctx context.Context, email, pass string) (string, error) { + user, err := am.authenticator.ValidateCredentials(ctx, email, pass) + if errors.Is(err, ErrInvalidCredentials) { + return "", err + } + if err != nil { + return "", fmt.Errorf("validating credentials: %w", err) + } + + return am.generateToken(ctx, user) +} + +func (am *defaultAuthManager) generateToken(ctx context.Context, user *User) (string, error) { + roles, err := am.roleManager.GetUserRoles(ctx, user) + if err != nil { + return "", fmt.Errorf("error getting user roles: %w", err) + } + + user.Roles = roles + + expiresAt := time.Now().Add(am.expirationTimeInMinutes) + tokenString, err := am.jwtManager.GenerateToken(ctx, user, expiresAt) + if err != nil { + return "", fmt.Errorf("generating token: %w", err) + } + + return tokenString, nil +} + +func (am *defaultAuthManager) RefreshToken(ctx context.Context, tokenString string) (string, error) { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return "", fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return "", ErrInvalidToken + } + + // TODO: find a way to not refresh the same token + // more than once - perhaps create a table and store invalid tokens + expiresAt := time.Now().Add(am.expirationTimeInMinutes) + tokenString, err = am.jwtManager.RefreshToken(ctx, tokenString, expiresAt) + if err != nil { + return "", fmt.Errorf("generating new refreshed token: %w", err) + } + + return tokenString, nil +} + +func (am *defaultAuthManager) ValidateToken(ctx context.Context, tokenString string) (bool, error) { + isValid, err := am.jwtManager.ValidateToken(ctx, tokenString) + if err != nil { + return false, fmt.Errorf("validating token: %w", err) + } + + return isValid, nil +} + +// AllRolesInTokenUser checks whether the user's token has all the roles passed by parameter. +func (am *defaultAuthManager) AllRolesInTokenUser(ctx context.Context, tokenString string, roleNames []string) (bool, error) { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return false, fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return false, ErrInvalidToken + } + + user, err := am.jwtManager.GetUserFromToken(ctx, tokenString) + if err != nil { + return false, fmt.Errorf("error getting user from token: %w", err) + } + + hasAllRoles, err := am.roleManager.HasAllRoles(ctx, user, roleNames) + if err != nil { + return false, fmt.Errorf("error validating user roles: %w", err) + } + + return hasAllRoles, nil +} + +// AnyRolesInTokenUser checks whether the user's token has one or more the roles passed by parameter. +func (am *defaultAuthManager) AnyRolesInTokenUser(ctx context.Context, tokenString string, roleNames []string) (bool, error) { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return false, fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return false, ErrInvalidToken + } + + user, err := am.jwtManager.GetUserFromToken(ctx, tokenString) + if err != nil { + return false, fmt.Errorf("error getting user from token: %w", err) + } + + hasAnyRoles, err := am.roleManager.HasAnyRoles(ctx, user, roleNames) + if err != nil { + return false, fmt.Errorf("error validating user roles: %w", err) + } + + return hasAnyRoles, nil +} + +// CreateUser creates a new user using Authenticator's CreateUser method. +func (am *defaultAuthManager) CreateUser(ctx context.Context, user *User, password string) (*User, error) { + user, err := am.authenticator.CreateUser(ctx, user, password) + if err != nil { + return nil, fmt.Errorf("error creating user: %w", err) + } + + return user, nil +} + +func (am *defaultAuthManager) UpdateUser(ctx context.Context, tokenString, firstName, lastName, email, password string) error { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return ErrInvalidToken + } + + user, err := am.jwtManager.GetUserFromToken(ctx, tokenString) + if err != nil { + return fmt.Errorf("error getting user from token: %w", err) + } + + err = am.authenticator.UpdateUser(ctx, user.ID, firstName, lastName, email, password) + if err != nil { + return fmt.Errorf("error updating user: %w", err) + } + + return nil +} + +// ForgotPassword handles the generation of a new password reset token for the user to set a new password. +func (am *defaultAuthManager) ForgotPassword(ctx context.Context, email string) (string, error) { + resetToken, err := am.authenticator.ForgotPassword(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return "", fmt.Errorf("user not found in auth forgot password: %w", err) + } + return "", fmt.Errorf("error on forgot password: %w", err) + } + + return resetToken, nil +} + +// ResetPassword sets the user's new password using a valid reset token generated in the ForgotPassword flow. +func (am *defaultAuthManager) ResetPassword(ctx context.Context, resetToken, newPassword string) error { + err := am.authenticator.ResetPassword(ctx, resetToken, newPassword) + if err != nil { + if errors.Is(err, ErrInvalidResetPasswordToken) { + return fmt.Errorf("invalid token in auth reset password: %w", err) + } + return fmt.Errorf("error on reset password: %w", err) + } + + return nil +} + +func (am *defaultAuthManager) ActivateUser(ctx context.Context, tokenString, userID string) error { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return ErrInvalidToken + } + + err = am.authenticator.ActivateUser(ctx, userID) + if err != nil { + return fmt.Errorf("error activating user ID %s: %w", userID, err) + } + + return nil +} + +func (am *defaultAuthManager) DeactivateUser(ctx context.Context, tokenString, userID string) error { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return ErrInvalidToken + } + + err = am.authenticator.DeactivateUser(ctx, userID) + if err != nil { + return fmt.Errorf("error deactivating user ID %s: %w", userID, err) + } + + return nil +} + +func (am *defaultAuthManager) UpdateUserRoles(ctx context.Context, tokenString, userID string, roles []string) error { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return ErrInvalidToken + } + + // TODO: pass all fields of the user + err = am.roleManager.UpdateRoles(ctx, &User{ID: userID}, roles) + if err != nil { + return fmt.Errorf("error updating user roles: %w", err) + } + + return nil +} + +func (am *defaultAuthManager) GetAllUsers(ctx context.Context, tokenString string) ([]User, error) { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return nil, fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return nil, ErrInvalidToken + } + + users, err := am.authenticator.GetAllUsers(ctx) + if err != nil { + return nil, fmt.Errorf("error getting all users: %w", err) + } + + return users, nil +} + +func (am *defaultAuthManager) GetUser(ctx context.Context, tokenString string) (*User, error) { + isValid, err := am.ValidateToken(ctx, tokenString) + if err != nil { + return nil, fmt.Errorf("validating token: %w", err) + } + + if !isValid { + return nil, ErrInvalidToken + } + + tokenUser, err := am.jwtManager.GetUserFromToken(ctx, tokenString) + if err != nil { + return nil, fmt.Errorf("error getting user from token: %w", err) + } + + // We get the user latest state + user, err := am.authenticator.GetUser(ctx, tokenUser.ID) + if err != nil { + return nil, fmt.Errorf("error getting user ID %s: %w", tokenUser.ID, err) + } + + roles, err := am.roleManager.GetUserRoles(ctx, user) + if err != nil { + return nil, fmt.Errorf("error getting user ID %s roles: %w", tokenUser.ID, err) + } + + user.Roles = roles + + return user, nil +} + +func (am *defaultAuthManager) ExpirationTimeInMinutes() time.Duration { + return am.expirationTimeInMinutes +} + +func (am *defaultAuthManager) MFADeviceRemembered(ctx context.Context, deviceID, userID string) (bool, error) { + return am.mfaManager.MFADeviceRemembered(ctx, deviceID, userID) +} + +func (am *defaultAuthManager) GetMFACode(ctx context.Context, deviceID, userID string) (string, error) { + return am.mfaManager.GenerateMFACode(ctx, deviceID, userID) +} + +func (am *defaultAuthManager) AuthenticateMFA(ctx context.Context, deviceID, code string, rememberMe bool) (string, error) { + if rememberMe { + err := am.mfaManager.RememberDevice(ctx, deviceID, code) + if err != nil { + return "", fmt.Errorf("error remembering device ID %s: %w", deviceID, err) + } + } + + userID, err := am.mfaManager.ValidateMFACode(ctx, deviceID, code) + if err != nil { + if errors.Is(err, ErrMFACodeInvalid) { + return "", ErrInvalidMFACode + } + return "", fmt.Errorf("error validating MFA code: %w", err) + } + + user, err := am.authenticator.GetUser(ctx, userID) + if err != nil { + return "", fmt.Errorf("error getting user ID %s: %w", userID, err) + } + + return am.generateToken(ctx, user) +} + +// Ensuring that defaultAuthManager is implementing AuthManager interface +var _ AuthManager = (*defaultAuthManager)(nil) diff --git a/stellar-auth/pkg/auth/auth_test.go b/stellar-auth/pkg/auth/auth_test.go new file mode 100644 index 000000000..c3ba10479 --- /dev/null +++ b/stellar-auth/pkg/auth/auth_test.go @@ -0,0 +1,1260 @@ +package auth + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_AuthManager_Authenticate(t *testing.T) { + authenticatorMock := &AuthenticatorMock{} + jwtManagerMock := &JWTManagerMock{} + roleManagerMock := &RoleManagerMock{} + + authManager := NewAuthManager( + WithCustomAuthenticatorOption(authenticatorMock), + WithCustomJWTManagerOption(jwtManagerMock), + WithCustomRoleManagerOption(roleManagerMock), + ) + + ctx := context.Background() + + t.Run("returns error when invalid credentials is provided", func(t *testing.T) { + email, password := "email@email.com", "pass123" + + authenticatorMock. + On("ValidateCredentials", ctx, email, password). + Return(nil, errors.New("invalid credentials")). + Once() + + token, err := authManager.Authenticate(ctx, email, password) + + assert.EqualError(t, err, "validating credentials: invalid credentials") + assert.Empty(t, token) + }) + + t.Run("returns error when get user roles fails", func(t *testing.T) { + email, password := "email@email.com", "pass123" + + expectedUser := &User{ + ID: "user-id", + Email: "email@email.com", + } + + authenticatorMock. + On("ValidateCredentials", ctx, email, password). + Return(expectedUser, nil). + Once() + + roleManagerMock. + On("GetUserRoles", ctx, expectedUser). + Return(nil, errUnexpectedError). + Once() + + token, err := authManager.Authenticate(ctx, email, password) + + assert.EqualError(t, err, "error getting user roles: unexpected error") + assert.Empty(t, token) + }) + + t.Run("returns error when generate token fails", func(t *testing.T) { + email, password := "email@email.com", "pass123" + + expectedUser := &User{ + ID: "user-id", + Email: "email@email.com", + } + + authenticatorMock. + On("ValidateCredentials", ctx, email, password). + Return(expectedUser, nil). + Once() + + roleManagerMock. + On("GetUserRoles", ctx, expectedUser). + Return([]string{"role1"}, nil). + Once() + + jwtManagerMock. + On("GenerateToken", ctx, expectedUser, mock.AnythingOfType("time.Time")). + Return("", errUnexpectedError). + Once() + + token, err := authManager.Authenticate(ctx, email, password) + + assert.EqualError(t, err, "generating token: unexpected error") + assert.Empty(t, token) + }) + + t.Run("returns the user JWT token successfully", func(t *testing.T) { + email, password := "email@email.com", "pass123" + + user := &User{ + ID: "user-id", + Email: "email@email.com", + } + + roles := []string{"role1"} + + expectedUser := &User{ + ID: "user-id", + Email: "email@email.com", + Roles: roles, + } + + authenticatorMock. + On("ValidateCredentials", ctx, email, password). + Return(user, nil). + Once() + + roleManagerMock. + On("GetUserRoles", ctx, user). + Return(roles, nil). + Once() + + expectedToken := "mytoken" + jwtManagerMock. + On("GenerateToken", ctx, expectedUser, mock.AnythingOfType("time.Time")). + Return(expectedToken, nil). + Once() + + token, err := authManager.Authenticate(ctx, email, password) + require.NoError(t, err) + + assert.Equal(t, expectedToken, token) + }) + + authenticatorMock.AssertExpectations(t) + jwtManagerMock.AssertExpectations(t) + roleManagerMock.AssertExpectations(t) +} + +func Test_AuthManager_ValidateToken(t *testing.T) { + jwtManagerMock := &JWTManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + isValid, err := authManager.ValidateToken(ctx, token) + + assert.EqualError(t, err, "validating token: unexpected error") + assert.False(t, isValid) + }) + + t.Run("returns false when token is invalid", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + isValid, err := authManager.ValidateToken(ctx, token) + require.NoError(t, err) + + assert.False(t, isValid) + }) + + t.Run("returns true when token is valid", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil) + + isValid, err := authManager.ValidateToken(ctx, token) + require.NoError(t, err) + + assert.True(t, isValid) + }) + + jwtManagerMock.AssertExpectations(t) +} + +func Test_AuthManager_RefreshToken(t *testing.T) { + jwtManagerMock := &JWTManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + refreshedToken, err := authManager.RefreshToken(ctx, token) + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + assert.Empty(t, refreshedToken) + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + refreshedToken, err := authManager.RefreshToken(ctx, token) + + assert.EqualError(t, err, ErrInvalidToken.Error()) + assert.Empty(t, refreshedToken) + }) + + t.Run("returns error when JWT Manager fails", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("RefreshToken", ctx, token, mock.AnythingOfType("time.Time")). + Return("", errUnexpectedError). + Once() + + refreshedToken, err := authManager.RefreshToken(ctx, token) + + assert.EqualError(t, err, "generating new refreshed token: unexpected error") + assert.Empty(t, refreshedToken) + }) + + t.Run("returns a new token successfully", func(t *testing.T) { + token := "myoldtoken" + newToken := "myfreshtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("RefreshToken", ctx, token, mock.AnythingOfType("time.Time")). + Return(newToken, nil). + Once() + + refreshedToken, err := authManager.RefreshToken(ctx, token) + require.NoError(t, err) + + assert.NotEqual(t, token, refreshedToken) + assert.Equal(t, newToken, refreshedToken) + }) + + jwtManagerMock.AssertExpectations(t) +} + +func Test_AuthManager_AllRolesInTokenUser(t *testing.T) { + jwtManagerMock := &JWTManagerMock{} + roleManagerMock := &RoleManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock), WithCustomRoleManagerOption(roleManagerMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + isValid, err := authManager.AllRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + assert.False(t, isValid) + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + isValid, err := authManager.AllRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, ErrInvalidToken.Error()) + assert.False(t, isValid) + }) + + t.Run("returns error when JWT Manager fails getting user from token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(nil, errUnexpectedError). + Once() + + isValid, err := authManager.AllRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, "error getting user from token: unexpected error") + assert.False(t, isValid) + }) + + t.Run("returns error when Role Manager fails verifying if user has roles", func(t *testing.T) { + token := "myoldtoken" + + user := &User{ + ID: "user-ID", + Email: "email@email.com", + Roles: []string{"role1"}, + } + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(user, nil). + Once() + + roleManagerMock. + On("HasAllRoles", ctx, user, []string{"role1"}). + Return(false, errUnexpectedError). + Once() + + isValid, err := authManager.AllRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, "error validating user roles: unexpected error") + assert.False(t, isValid) + }) + + t.Run("validates the user roles correctly", func(t *testing.T) { + token := "myoldtoken" + + user := &User{ + ID: "user-ID", + Email: "email@email.com", + Roles: []string{"role1", "role3"}, + } + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Times(4). + On("GetUserFromToken", ctx, token). + Return(user, nil). + Times(4) + + roleManagerMock. + On("HasAllRoles", ctx, user, []string{"role1"}). + Return(true, nil). + Once() + + isValid, err := authManager.AllRolesInTokenUser(ctx, token, []string{"role1"}) + require.NoError(t, err) + assert.True(t, isValid) + + roleManagerMock. + On("HasAllRoles", ctx, user, []string{"role2", "role3"}). + Return(false, nil). + Once() + + isValid, err = authManager.AllRolesInTokenUser(ctx, token, []string{"role2", "role3"}) + require.NoError(t, err) + assert.False(t, isValid) + + roleManagerMock. + On("HasAllRoles", ctx, user, []string{"role2"}). + Return(false, nil). + Once() + + isValid, err = authManager.AllRolesInTokenUser(ctx, token, []string{"role2"}) + require.NoError(t, err) + assert.False(t, isValid) + + roleManagerMock. + On("HasAllRoles", ctx, user, []string{"role1", "role3"}). + Return(true, nil). + Once() + + isValid, err = authManager.AllRolesInTokenUser(ctx, token, []string{"role1", "role3"}) + require.NoError(t, err) + assert.True(t, isValid) + }) + + jwtManagerMock.AssertExpectations(t) + roleManagerMock.AssertExpectations(t) +} + +func Test_AuthManager_AnyRolesInTokenUser(t *testing.T) { + jwtManagerMock := &JWTManagerMock{} + roleManagerMock := &RoleManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock), WithCustomRoleManagerOption(roleManagerMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + isValid, err := authManager.AnyRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + assert.False(t, isValid) + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + isValid, err := authManager.AnyRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, ErrInvalidToken.Error()) + assert.False(t, isValid) + }) + + t.Run("returns error when JWT Manager fails getting user from token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(nil, errUnexpectedError). + Once() + + isValid, err := authManager.AnyRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, "error getting user from token: unexpected error") + assert.False(t, isValid) + }) + + t.Run("returns error when Role Manager fails verifying if user has roles", func(t *testing.T) { + token := "myoldtoken" + + user := &User{ + ID: "user-ID", + Email: "email@email.com", + Roles: []string{"role1"}, + } + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(user, nil). + Once() + + roleManagerMock. + On("HasAnyRoles", ctx, user, []string{"role1"}). + Return(false, errUnexpectedError). + Once() + + isValid, err := authManager.AnyRolesInTokenUser(ctx, token, []string{"role1"}) + + assert.EqualError(t, err, "error validating user roles: unexpected error") + assert.False(t, isValid) + }) + + t.Run("validates the user roles correctly", func(t *testing.T) { + token := "myoldtoken" + + user := &User{ + ID: "user-ID", + Email: "email@email.com", + Roles: []string{"role1", "role3"}, + } + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Times(4). + On("GetUserFromToken", ctx, token). + Return(user, nil). + Times(4) + + roleManagerMock. + On("HasAnyRoles", ctx, user, []string{"role1"}). + Return(true, nil). + Once() + + isValid, err := authManager.AnyRolesInTokenUser(ctx, token, []string{"role1"}) + require.NoError(t, err) + assert.True(t, isValid) + + roleManagerMock. + On("HasAnyRoles", ctx, user, []string{"role2", "role3"}). + Return(true, nil). + Once() + + isValid, err = authManager.AnyRolesInTokenUser(ctx, token, []string{"role2", "role3"}) + require.NoError(t, err) + assert.True(t, isValid) + + roleManagerMock. + On("HasAnyRoles", ctx, user, []string{"role2"}). + Return(false, nil). + Once() + + isValid, err = authManager.AnyRolesInTokenUser(ctx, token, []string{"role2"}) + require.NoError(t, err) + assert.False(t, isValid) + + roleManagerMock. + On("HasAnyRoles", ctx, user, []string{"role1", "role3"}). + Return(true, nil). + Once() + + isValid, err = authManager.AnyRolesInTokenUser(ctx, token, []string{"role1", "role3"}) + require.NoError(t, err) + assert.True(t, isValid) + }) + + jwtManagerMock.AssertExpectations(t) + roleManagerMock.AssertExpectations(t) +} + +func Test_AuthManager_CreateUser(t *testing.T) { + authenticatorMock := &AuthenticatorMock{} + + authManager := NewAuthManager( + WithCustomAuthenticatorOption(authenticatorMock), + ) + + ctx := context.Background() + + t.Run("returns error when Authenticator fails creating user", func(t *testing.T) { + user := &User{ + Email: "email@email.com", + FirstName: "First", + LastName: "Last", + } + + password := "mysecret" + + authenticatorMock. + On("CreateUser", ctx, user, password). + Return(nil, errUnexpectedError). + Once() + + u, err := authManager.CreateUser(ctx, user, password) + + assert.EqualError(t, err, "error creating user: unexpected error") + assert.Nil(t, u) + }) + + t.Run("create user correctly", func(t *testing.T) { + newUser := &User{ + Email: "email@email.com", + FirstName: "First", + LastName: "Last", + } + + password := "mysecret" + + expectedUser := &User{ + ID: "user-id", + Email: "email@email.com", + FirstName: "First", + LastName: "Last", + } + + authenticatorMock. + On("CreateUser", ctx, newUser, password). + Return(expectedUser, nil). + Once() + + u, err := authManager.CreateUser(ctx, newUser, password) + require.NoError(t, err) + + assert.Equal(t, expectedUser, u) + }) + + authenticatorMock.AssertExpectations(t) +} + +func Test_AuthManager_ActivateUser(t *testing.T) { + authenticatorMock := &AuthenticatorMock{} + jwtManagerMock := &JWTManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock), WithCustomAuthenticatorOption(authenticatorMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + err := authManager.ActivateUser(ctx, token, "user-id") + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + err := authManager.ActivateUser(ctx, token, "user-id") + + assert.EqualError(t, err, "invalid token") + }) + + t.Run("returns error when Authenticator fails", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once() + + authenticatorMock. + On("ActivateUser", ctx, userID). + Return(errUnexpectedError). + Once() + + err := authManager.ActivateUser(ctx, token, userID) + + assert.EqualError(t, err, "error activating user ID user-id: unexpected error") + }) + + t.Run("activate user successfully", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once() + + authenticatorMock. + On("ActivateUser", ctx, userID). + Return(nil). + Once() + + err := authManager.ActivateUser(ctx, token, userID) + + assert.Nil(t, err) + }) +} + +func Test_AuthManager_UpdateUser(t *testing.T) { + authenticatorMock := &AuthenticatorMock{} + jwtManagerMock := &JWTManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock), WithCustomAuthenticatorOption(authenticatorMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + err := authManager.UpdateUser(ctx, token, "First", "Last", "email@email.com", "mysecret") + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(nil, errUnexpectedError). + Once() + + err = authManager.UpdateUser(ctx, token, "First", "Last", "email@email.com", "mysecret") + + assert.EqualError(t, err, "error getting user from token: unexpected error") + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + err := authManager.UpdateUser(ctx, token, "First", "Last", "email@email.com", "mysecret") + + assert.EqualError(t, err, "invalid token") + }) + + t.Run("returns error when Authenticator fails", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + firstName, lastName, email, password := "First", "Last", "email@email.com", "mysecret" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(&User{ID: userID}, nil). + Once() + + authenticatorMock. + On("UpdateUser", ctx, userID, firstName, lastName, email, password). + Return(errUnexpectedError). + Once() + + err := authManager.UpdateUser(ctx, token, "First", "Last", "email@email.com", "mysecret") + + assert.EqualError(t, err, "error updating user: unexpected error") + }) + + t.Run("updates user successfully", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + firstName, lastName, email, password := "First", "Last", "email@email.com", "mysecret" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(&User{ID: userID}, nil). + Once() + + authenticatorMock. + On("UpdateUser", ctx, userID, firstName, lastName, email, password). + Return(nil). + Once() + + err := authManager.UpdateUser(ctx, token, "First", "Last", "email@email.com", "mysecret") + + assert.Nil(t, err) + }) +} + +func Test_AuthManager_DeactivateUser(t *testing.T) { + authenticatorMock := &AuthenticatorMock{} + jwtManagerMock := &JWTManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock), WithCustomAuthenticatorOption(authenticatorMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + err := authManager.DeactivateUser(ctx, token, "user-id") + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + err := authManager.DeactivateUser(ctx, token, "user-id") + + assert.EqualError(t, err, "invalid token") + }) + + t.Run("returns error when Authenticator fails", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once() + + authenticatorMock. + On("DeactivateUser", ctx, userID). + Return(errUnexpectedError). + Once() + + err := authManager.DeactivateUser(ctx, token, userID) + + assert.EqualError(t, err, "error deactivating user ID user-id: unexpected error") + }) + + t.Run("deactivate user successfully", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once() + + authenticatorMock. + On("DeactivateUser", ctx, userID). + Return(nil). + Once() + + err := authManager.DeactivateUser(ctx, token, userID) + + assert.Nil(t, err) + }) +} + +func Test_AuthManager_UpdateUserRoles(t *testing.T) { + jwtManagerMock := &JWTManagerMock{} + roleManagerMock := &RoleManagerMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock), WithCustomRoleManagerOption(roleManagerMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + err := authManager.UpdateUserRoles(ctx, token, "user-id", []string{"role1"}) + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + err := authManager.UpdateUserRoles(ctx, token, "user-id", []string{"role1"}) + + assert.EqualError(t, err, "invalid token") + }) + + t.Run("returns error when Authenticator fails", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + roles := []string{"role1"} + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once() + + roleManagerMock. + On("UpdateRoles", ctx, &User{ID: userID}, roles). + Return(errUnexpectedError). + Once() + + err := authManager.UpdateUserRoles(ctx, token, userID, roles) + + assert.EqualError(t, err, "error updating user roles: unexpected error") + }) + + t.Run("update user roles successfully", func(t *testing.T) { + token := "mytoken" + userID := "user-id" + roles := []string{"role1"} + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once() + + roleManagerMock. + On("UpdateRoles", ctx, &User{ID: userID}, roles). + Return(nil). + Once() + + err := authManager.UpdateUserRoles(ctx, token, userID, roles) + + assert.Nil(t, err) + }) +} + +func Test_AuthManager_WithExpirationTimeInMinutesOption(t *testing.T) { + authManager := NewAuthManager(WithExpirationTimeInMinutesOption(10)) + assert.Equal(t, time.Minute*10, authManager.ExpirationTimeInMinutes()) +} + +func Test_AuthManager_ForgotPassword(t *testing.T) { + authenticatorMock := &AuthenticatorMock{} + authManager := NewAuthManager( + WithCustomAuthenticatorOption(authenticatorMock), + ) + + ctx := context.Background() + + t.Run("returns error when user is not found", func(t *testing.T) { + authenticatorMock. + On("ForgotPassword", ctx, "wrongemail@email.com"). + Return("", ErrUserNotFound). + Once() + + resetToken, err := authManager.ForgotPassword(ctx, "wrongemail@email.com") + assert.EqualError(t, err, "user not found in auth forgot password: user not found") + assert.Empty(t, resetToken) + }) + + t.Run("returns error when authenticator fails", func(t *testing.T) { + authenticatorMock. + On("ForgotPassword", ctx, "wrongemail@email.com"). + Return("", errUnexpectedError). + Once() + + resetToken, err := authManager.ForgotPassword(ctx, "wrongemail@email.com") + assert.EqualError(t, err, "error on forgot password: unexpected error") + assert.Empty(t, resetToken) + }) + + t.Run("creates a reset token successfully", func(t *testing.T) { + authenticatorMock. + On("ForgotPassword", ctx, "valid@email.com"). + Return("resettoken", nil). + Once() + + resetToken, err := authManager.ForgotPassword(ctx, "valid@email.com") + require.NoError(t, err) + assert.Equal(t, "resettoken", resetToken) + }) + + authenticatorMock.AssertExpectations(t) +} + +func Test_AuthManager_ResetPassword(t *testing.T) { + authenticatorMock := &AuthenticatorMock{} + authManager := NewAuthManager( + WithCustomAuthenticatorOption(authenticatorMock), + ) + + ctx := context.Background() + + t.Run("returns error when the reset token is invalid", func(t *testing.T) { + authenticatorMock. + On("ResetPassword", ctx, "invalidToken", "password123"). + Return(ErrInvalidResetPasswordToken). + Once() + + err := authManager.ResetPassword(ctx, "invalidToken", "password123") + require.EqualError(t, err, "invalid token in auth reset password: invalid reset password token") + }) + + t.Run("returns error when authenticator fails", func(t *testing.T) { + authenticatorMock. + On("ResetPassword", ctx, "validToken", "password123"). + Return(errUnexpectedError). + Once() + + err := authManager.ResetPassword(ctx, "validToken", "password123") + assert.EqualError(t, err, "error on reset password: unexpected error") + }) + + t.Run("no error with a valid reset token", func(t *testing.T) { + authenticatorMock. + On("ResetPassword", ctx, "goodtoken", "password123"). + Return(nil). + Once() + + err := authManager.ResetPassword(ctx, "goodtoken", "password123") + require.NoError(t, err) + }) + + authenticatorMock.AssertExpectations(t) +} + +func Test_AuthManager_GetAllUsers(t *testing.T) { + jwtManagerMock := &JWTManagerMock{} + authenticatorMock := &AuthenticatorMock{} + authManager := NewAuthManager(WithCustomJWTManagerOption(jwtManagerMock), WithCustomAuthenticatorOption(authenticatorMock)) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + users, err := authManager.GetAllUsers(ctx, token) + + assert.Nil(t, users) + assert.EqualError(t, err, "validating token: validating token: unexpected error") + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + users, err := authManager.GetAllUsers(ctx, token) + + assert.Nil(t, users) + assert.EqualError(t, err, "invalid token") + }) + + t.Run("returns error when Authenticator fails", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once() + + authenticatorMock. + On("GetAllUsers", ctx). + Return(nil, errUnexpectedError). + Once() + + users, err := authManager.GetAllUsers(ctx, token) + + assert.EqualError(t, err, "error getting all users: unexpected error") + assert.Nil(t, users) + }) + + t.Run("returns users successfully", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Twice() + + authenticatorMock. + On("GetAllUsers", ctx). + Return([]User{}, nil). + Once() + + users, err := authManager.GetAllUsers(ctx, token) + require.NoError(t, err) + assert.Empty(t, users) + + expectedUsers := []User{ + { + ID: "user1-ID", + FirstName: "First", + LastName: "Last", + Email: "user1@email.com", + IsOwner: false, + IsActive: false, + Roles: []string{"role1"}, + }, + { + ID: "user2-ID", + FirstName: "First", + LastName: "Last", + Email: "user2@email.com", + IsOwner: true, + IsActive: true, + Roles: []string{"role2"}, + }, + } + + authenticatorMock. + On("GetAllUsers", ctx). + Return(expectedUsers, nil). + Once() + + users, err = authManager.GetAllUsers(ctx, token) + require.NoError(t, err) + assert.Equal(t, expectedUsers, users) + }) + + jwtManagerMock.AssertExpectations(t) + authenticatorMock.AssertExpectations(t) +} + +func Test_AuthManager_GetUser(t *testing.T) { + jwtManagerMock := &JWTManagerMock{} + authenticatorMock := &AuthenticatorMock{} + roleManagerMock := &RoleManagerMock{} + authManager := NewAuthManager( + WithCustomJWTManagerOption(jwtManagerMock), + WithCustomAuthenticatorOption(authenticatorMock), + WithCustomRoleManagerOption(roleManagerMock), + ) + + ctx := context.Background() + + t.Run("returns error when JWT Manager fails validating token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, errUnexpectedError). + Once() + + user, err := authManager.GetUser(ctx, token) + + assert.EqualError(t, err, "validating token: validating token: unexpected error") + assert.Nil(t, user) + }) + + t.Run("returns error when token is expired", func(t *testing.T) { + token := "myoldtoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(false, nil). + Once() + + user, err := authManager.GetUser(ctx, token) + + assert.EqualError(t, err, "invalid token") + assert.Nil(t, user) + }) + + t.Run("returns error when JWT Manager fails getting user from token", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(nil, errUnexpectedError). + Once() + + user, err := authManager.GetUser(ctx, token) + + assert.EqualError(t, err, "error getting user from token: unexpected error") + assert.Nil(t, user) + }) + + t.Run("returns error when Authenticator fails", func(t *testing.T) { + token := "mytoken" + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(&User{ + ID: "user-id", + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + }, nil). + Once() + + authenticatorMock. + On("GetUser", ctx, "user-id"). + Return(nil, errUnexpectedError). + Once() + + user, err := authManager.GetUser(ctx, token) + + assert.EqualError(t, err, "error getting user ID user-id: unexpected error") + assert.Nil(t, user) + }) + + t.Run("returns error when get user roles fails", func(t *testing.T) { + token := "mytoken" + + u := &User{ + ID: "user-id", + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + } + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(u, nil). + Once() + + authenticatorMock. + On("GetUser", ctx, u.ID). + Return(u, nil). + Once() + + roleManagerMock. + On("GetUserRoles", ctx, u). + Return(nil, errUnexpectedError). + Once() + + user, err := authManager.GetUser(ctx, token) + + assert.EqualError(t, err, "error getting user ID user-id roles: unexpected error") + assert.Nil(t, user) + }) + + t.Run("gets user successfully", func(t *testing.T) { + token := "mytoken" + + u := &User{ + ID: "user-id", + FirstName: "First", + LastName: "Last", + Email: "email@email.com", + } + + jwtManagerMock. + On("ValidateToken", ctx, token). + Return(true, nil). + Once(). + On("GetUserFromToken", ctx, token). + Return(u, nil). + Once() + + authenticatorMock. + On("GetUser", ctx, u.ID). + Return(u, nil). + Once() + + roleManagerMock. + On("GetUserRoles", ctx, u). + Return([]string{"role1", "role2"}, nil). + Once() + + user, err := authManager.GetUser(ctx, token) + require.NoError(t, err) + + assert.Equal(t, u.ID, user.ID) + assert.Equal(t, u.FirstName, user.FirstName) + assert.Equal(t, u.LastName, user.LastName) + assert.Equal(t, u.Email, user.Email) + assert.Equal(t, []string{"role1", "role2"}, user.Roles) + }) + + authenticatorMock.AssertExpectations(t) + jwtManagerMock.AssertExpectations(t) + roleManagerMock.AssertExpectations(t) +} diff --git a/stellar-auth/pkg/auth/authenticator.go b/stellar-auth/pkg/auth/authenticator.go new file mode 100644 index 000000000..eff8e5920 --- /dev/null +++ b/stellar-auth/pkg/auth/authenticator.go @@ -0,0 +1,468 @@ +package auth + +import ( + "context" + "crypto/rand" + "database/sql" + "errors" + "fmt" + "math/big" + "strings" + "time" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/utils" +) + +var ( + ErrInvalidCredentials = errors.New("invalid credentials") + ErrNoRowsAffected = errors.New("no rows affected") + ErrInvalidResetPasswordToken = errors.New("invalid reset password token") + ErrUserNotFound = errors.New("user not found") + ErrUserEmailAlreadyExists = errors.New("a user with this email already exists") + ErrUserHasValidToken = errors.New("user has a valid token") +) + +const ( + resetTokenLength = 10 +) + +type Authenticator interface { + ValidateCredentials(ctx context.Context, email, password string) (*User, error) + // CreateUser creates a new user it receives a user object and the password + CreateUser(ctx context.Context, user *User, password string) (*User, error) + UpdateUser(ctx context.Context, ID, firstName, lastName, email, password string) error + ActivateUser(ctx context.Context, userID string) error + DeactivateUser(ctx context.Context, userID string) error + ForgotPassword(ctx context.Context, email string) (string, error) + ResetPassword(ctx context.Context, resetToken, password string) error + GetAllUsers(ctx context.Context) ([]User, error) + GetUser(ctx context.Context, userID string) (*User, error) +} + +type defaultAuthenticator struct { + dbConnectionPool db.DBConnectionPool + passwordEncrypter PasswordEncrypter + resetTokenExpirationHours time.Duration +} + +type authUser struct { + ID string `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string `db:"email"` + EncryptedPassword string `db:"encrypted_password"` +} + +func (a *defaultAuthenticator) ValidateCredentials(ctx context.Context, email, password string) (*User, error) { + const query = ` + SELECT + u.id, + u.first_name, + u.last_name, + u.encrypted_password + FROM + auth_users u + WHERE + email = $1 AND is_active = true + ` + + au := authUser{} + err := a.dbConnectionPool.GetContext(ctx, &au, query, email) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrInvalidCredentials + } + + return nil, fmt.Errorf("querying user: %w", err) + } + + isEqual, err := a.passwordEncrypter.ComparePassword(ctx, au.EncryptedPassword, password) + if err != nil { + return nil, fmt.Errorf("comparing password: %w", err) + } + if !isEqual { + return nil, ErrInvalidCredentials + } + + return &User{ + ID: au.ID, + Email: email, + FirstName: au.FirstName, + LastName: au.LastName, + }, nil +} + +// CreateUser creates a user in the database. If a empty password is passed by parameter, a random password is generated, +// so the user can go through the ForgotPassword flow. +func (a *defaultAuthenticator) CreateUser(ctx context.Context, user *User, password string) (*User, error) { + if err := user.Validate(); err != nil { + return nil, fmt.Errorf("error validating user fields: %w", err) + } + + // In case no password is passed we generate a random OTP (One Time Password) + if password == "" { + // Random length pasword + randomNumber, err := rand.Int(rand.Reader, big.NewInt(maxPasswordLength-minPasswordLength+1)) + if err != nil { + return nil, fmt.Errorf("error generating random number in create user: %w", err) + } + + passwordLength := int(randomNumber.Int64() + minPasswordLength) + password, err = utils.StringWithCharset(passwordLength, utils.PasswordCharset) + if err != nil { + return nil, fmt.Errorf("error generating random password string in create user: %w", err) + } + } + + encryptedPassword, err := a.passwordEncrypter.Encrypt(ctx, password) + if err != nil { + return nil, fmt.Errorf("error encrypting password: %w", err) + } + + const query = ` + INSERT INTO auth_users + (email, encrypted_password, first_name, last_name, roles, is_owner) + VALUES + ($1, $2, $3, $4, $5, $6) + RETURNING id + ` + + var userID string + err = a.dbConnectionPool.GetContext(ctx, &userID, query, user.Email, encryptedPassword, user.FirstName, user.LastName, pq.Array(user.Roles), user.IsOwner) + if err != nil { + if pqError, ok := err.(*pq.Error); ok && pqError.Constraint == "auth_users_email_key" { + return nil, ErrUserEmailAlreadyExists + } + return nil, fmt.Errorf("error inserting user: %w", err) + } + + user.ID = userID + user.IsActive = true + + return user, nil +} + +func (a *defaultAuthenticator) UpdateUser(ctx context.Context, ID, firstName, lastName, email, password string) error { + if firstName == "" && lastName == "" && email == "" && password == "" { + return fmt.Errorf("provide at least one of these values: firstName, lastName, email or password") + } + + query := ` + UPDATE + auth_users + SET + %s + WHERE id = ? + ` + + fields := []string{} + args := []interface{}{} + if firstName != "" { + fields = append(fields, "first_name = ?") + args = append(args, firstName) + } + + if lastName != "" { + fields = append(fields, "last_name = ?") + args = append(args, lastName) + } + + if email != "" { + if err := utils.ValidateEmail(email); err != nil { + return fmt.Errorf("error validating email: %w", err) + } + + fields = append(fields, "email = ?") + args = append(args, email) + } + + if password != "" { + encryptedPassword, err := a.passwordEncrypter.Encrypt(ctx, password) + if err != nil { + if !errors.Is(err, ErrPasswordTooShort) { + return fmt.Errorf("error encrypting password: %w", err) + } + return err + } + + fields = append(fields, "encrypted_password = ?") + args = append(args, encryptedPassword) + } + + query = a.dbConnectionPool.Rebind(fmt.Sprintf(query, strings.Join(fields, ", "))) + args = append(args, ID) + + res, err := a.dbConnectionPool.ExecContext(ctx, query, args...) + if err != nil { + return fmt.Errorf("error updating user in the database: %w", err) + } + + numRowsAffected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("error getting the number of rows affected: %w", err) + } + if numRowsAffected == 0 { + return ErrNoRowsAffected + } + + return nil +} + +func (a *defaultAuthenticator) updateIsActive(ctx context.Context, userID string, isActive bool) error { + const query = "UPDATE auth_users SET is_active = $1 WHERE id = $2" + + result, err := a.dbConnectionPool.ExecContext(ctx, query, isActive, userID) + if err != nil { + return fmt.Errorf("error updating is_active for user ID %s: %w", userID, err) + } + + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + if numRowsAffected == 0 { + return ErrNoRowsAffected + } + + return nil +} + +func (a *defaultAuthenticator) ActivateUser(ctx context.Context, userID string) error { + err := a.updateIsActive(ctx, userID, true) + if err != nil { + return fmt.Errorf("error activating user ID %s: %w", userID, err) + } + + return nil +} + +func (a *defaultAuthenticator) DeactivateUser(ctx context.Context, userID string) error { + err := a.updateIsActive(ctx, userID, false) + if err != nil { + return fmt.Errorf("error deactivating user ID %s: %w", userID, err) + } + + return nil +} + +func (a *defaultAuthenticator) ForgotPassword(ctx context.Context, email string) (string, error) { + if email == "" { + return "", fmt.Errorf("error generating user reset password token: email cannot be empty") + } + + resetToken, err := utils.StringWithCharset(resetTokenLength, utils.DefaultCharset) + if err != nil { + return "", fmt.Errorf("error generating random reset token in forgot password: %w", err) + } + + checkValidTokenQuery := ` + SELECT EXISTS ( + SELECT 1 + FROM auth_user_password_reset ar + INNER JOIN auth_users au ON ar.auth_user_id = au.id + WHERE au.email = $1 + AND ar.is_valid = true + AND (ar.created_at + INTERVAL '20 minutes') > now() + ) + ` + var hasValidToken bool + err = a.dbConnectionPool.GetContext(ctx, &hasValidToken, checkValidTokenQuery, email) + if err != nil { + return "", fmt.Errorf("error checking if user has valid token: %w", err) + } + + if hasValidToken { + return "", ErrUserHasValidToken + } + + q := ` + WITH auth_user_reset_token_info AS ( + SELECT id, $2 as reset_token FROM auth_users WHERE email = $1 + ) + INSERT INTO + auth_user_password_reset (auth_user_id, token) + SELECT id, reset_token FROM auth_user_reset_token_info + ` + result, err := a.dbConnectionPool.ExecContext(ctx, q, email, resetToken) + if err != nil { + return "", fmt.Errorf("error inserting user reset password token in the database: %w", err) + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return "", fmt.Errorf("error getting rows affected inserting user reset password token in the database: %w", err) + } + if rowsAffected == 0 { + return "", ErrUserNotFound + } + + return resetToken, nil +} + +func (a *defaultAuthenticator) ResetPassword(ctx context.Context, resetToken, password string) error { + return db.RunInTransaction(ctx, a.dbConnectionPool, nil, func(dbTx db.DBTransaction) error { + query := ` + SELECT + auth_user_id, created_at + FROM + auth_user_password_reset + WHERE + token = $1 AND is_valid = true + ` + + type authUserPasswordReset struct { + UserID string `db:"auth_user_id"` + CreatedAt time.Time `db:"created_at"` + } + + var aupr authUserPasswordReset + err := dbTx.GetContext(ctx, &aupr, query, resetToken) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrInvalidResetPasswordToken + } + return fmt.Errorf("error searching password reset token for user in database: %w", err) + } + + // Token is only valid for 20 minutes + if aupr.CreatedAt.Add(time.Minute * 20).Before(time.Now()) { + return ErrInvalidResetPasswordToken + } + + encryptedPassword, err := a.passwordEncrypter.Encrypt(ctx, password) + if err != nil { + return fmt.Errorf("error trying to encrypt user password: %w", err) + } + + query = `UPDATE auth_users SET encrypted_password = $1 WHERE id = $2` + _, err = dbTx.ExecContext(ctx, query, encryptedPassword, aupr.UserID) + if err != nil { + return fmt.Errorf("error reseting user password in the database: %w", err) + } + + err = a.invalidateResetPasswordToken(ctx, dbTx, resetToken) + if err != nil { + return fmt.Errorf("error invalidating reset password token: %w", err) + } + + return nil + }) +} + +func (a *defaultAuthenticator) invalidateResetPasswordToken(ctx context.Context, dbTx db.DBTransaction, resetToken string) error { + q := "UPDATE auth_user_password_reset SET is_valid = false WHERE token = $1" + _, err := dbTx.ExecContext(ctx, q, resetToken) + if err != nil { + return fmt.Errorf("error invalidating reset password token in the database: %w", err) + } + + return nil +} + +func (a *defaultAuthenticator) GetAllUsers(ctx context.Context) ([]User, error) { + const query = ` + SELECT + id, + first_name, + last_name, + email, + roles, + is_owner, + is_active + FROM + auth_users + ` + + dbUsers := []struct { + ID string `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string `db:"email"` + Roles pq.StringArray `db:"roles"` + IsOwner bool `db:"is_owner"` + IsActive bool `db:"is_active"` + }{} + err := a.dbConnectionPool.SelectContext(ctx, &dbUsers, query) + if err != nil { + return nil, fmt.Errorf("error querying all users in the database: %w", err) + } + + users := []User{} + for _, dbUser := range dbUsers { + users = append(users, User{ + ID: dbUser.ID, + FirstName: dbUser.FirstName, + LastName: dbUser.LastName, + Email: dbUser.Email, + IsOwner: dbUser.IsOwner, + IsActive: dbUser.IsActive, + Roles: dbUser.Roles, + }) + } + + return users, nil +} + +func (a *defaultAuthenticator) GetUser(ctx context.Context, userID string) (*User, error) { + const query = ` + SELECT + first_name, + last_name, + email + FROM + auth_users + WHERE + id = $1 AND is_active = true + ` + + var u authUser + err := a.dbConnectionPool.GetContext(ctx, &u, query, userID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("error querying user ID %s: %w", userID, err) + } + + return &User{ + ID: userID, + FirstName: u.FirstName, + LastName: u.LastName, + Email: u.Email, + }, nil +} + +type defaultAuthenticatorOption func(a *defaultAuthenticator) + +func newDefaultAuthenticator(options ...defaultAuthenticatorOption) *defaultAuthenticator { + authenticator := &defaultAuthenticator{} + + for _, option := range options { + option(authenticator) + } + + return authenticator +} + +func withAuthenticatorDatabaseConnectionPool(dbConnectionPool db.DBConnectionPool) defaultAuthenticatorOption { + return func(a *defaultAuthenticator) { + a.dbConnectionPool = dbConnectionPool + } +} + +func withPasswordEncrypter(passwordEncrypter PasswordEncrypter) defaultAuthenticatorOption { + return func(a *defaultAuthenticator) { + a.passwordEncrypter = passwordEncrypter + } +} + +func withResetTokenExpirationHours(expirationHours time.Duration) defaultAuthenticatorOption { + return func(a *defaultAuthenticator) { + a.resetTokenExpirationHours = expirationHours + } +} + +// Ensuring that defaultAuthenticator is implementing Authenticator interface +var _ Authenticator = (*defaultAuthenticator)(nil) diff --git a/stellar-auth/pkg/auth/authenticator_test.go b/stellar-auth/pkg/auth/authenticator_test.go new file mode 100644 index 000000000..cc0b160ce --- /dev/null +++ b/stellar-auth/pkg/auth/authenticator_test.go @@ -0,0 +1,868 @@ +package auth + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +var errUnexpectedError = errors.New("unexpected error") + +func assertUserIsActive(t *testing.T, ctx context.Context, dbConnectionPool db.DBConnectionPool, userID string, expectedIsActive bool) { + const query = "SELECT is_active FROM auth_users WHERE id = $1" + + var isActive bool + err := dbConnectionPool.GetContext(ctx, &isActive, query, userID) + require.NoError(t, err) + + assert.Equal(t, expectedIsActive, isActive) +} + +func Test_DefaultAuthenticator_ValidateCredential(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool), withPasswordEncrypter(passwordEncrypterMock)) + + ctx := context.Background() + + t.Run("returns error when email is not found", func(t *testing.T) { + email, pass := "email@email.com", "pass1234" + + user, err := authenticator.ValidateCredentials(ctx, email, pass) + + assert.EqualError(t, err, ErrInvalidCredentials.Error()) + assert.Nil(t, user) + }) + + t.Run("returns error when Password Encrypter fails comparing password and hash", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + + password := "wrongpassword" + passwordEncrypterMock. + On("ComparePassword", ctx, randUser.EncryptedPassword, password). + Return(false, errUnexpectedError). + Once() + + user, err := authenticator.ValidateCredentials(ctx, randUser.Email, password) + + assert.EqualError(t, err, "comparing password: unexpected error") + assert.Nil(t, user) + }) + + t.Run("returns error when password is wrong", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + + password := "wrongpassword" + passwordEncrypterMock. + On("ComparePassword", ctx, randUser.EncryptedPassword, password). + Return(false, nil). + Once() + + user, err := authenticator.ValidateCredentials(ctx, randUser.Email, password) + + assert.EqualError(t, err, ErrInvalidCredentials.Error()) + assert.Nil(t, user) + }) + + t.Run("returns error when user is not active", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + err := authenticator.updateIsActive(ctx, randUser.ID, false) + require.NoError(t, err) + + user, err := authenticator.ValidateCredentials(ctx, randUser.Email, randUser.Password) + + assert.EqualError(t, err, ErrInvalidCredentials.Error()) + assert.Nil(t, user) + }) + + t.Run("returns user successfully", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + + passwordEncrypterMock. + On("ComparePassword", ctx, randUser.EncryptedPassword, randUser.Password). + Return(true, nil). + Once() + + user, err := authenticator.ValidateCredentials(ctx, randUser.Email, randUser.Password) + require.NoError(t, err) + + assert.Equal(t, randUser.Email, user.Email) + assert.Equal(t, randUser.ID, user.ID) + assert.Equal(t, randUser.FirstName, user.FirstName) + assert.Equal(t, randUser.LastName, user.LastName) + }) + + passwordEncrypterMock.AssertExpectations(t) +} + +func Test_DefaultAuthenticator_CreateUser(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool), withPasswordEncrypter(passwordEncrypterMock)) + + ctx := context.Background() + + t.Run("returns error when user is not valid", func(t *testing.T) { + user := &User{ + Email: "", + FirstName: "", + LastName: "", + } + + password := "mysecret" + + // Email + u, err := authenticator.CreateUser(ctx, user, password) + + assert.Nil(t, u) + assert.EqualError(t, err, "error validating user fields: email is required") + + user.Email = "invalid" + u, err = authenticator.CreateUser(ctx, user, password) + + assert.Nil(t, u) + assert.EqualError(t, err, `error validating user fields: email is invalid: the provided email "invalid" is not valid`) + + // First name + user.Email = "email@email.com" + u, err = authenticator.CreateUser(ctx, user, password) + + assert.Nil(t, u) + assert.EqualError(t, err, "error validating user fields: first name is required") + + // Last name + user.FirstName = "First" + u, err = authenticator.CreateUser(ctx, user, password) + + assert.Nil(t, u) + assert.EqualError(t, err, "error validating user fields: last name is required") + }) + + t.Run("returns error when password is invalid", func(t *testing.T) { + user := &User{ + Email: "email@email.com", + FirstName: "First", + LastName: "Last", + } + + password := "secret" + + passwordEncrypterMock. + On("Encrypt", ctx, password). + Return("", ErrPasswordTooShort). + Once() + + u, err := authenticator.CreateUser(ctx, user, password) + + assert.Nil(t, u) + assert.EqualError(t, err, "error encrypting password: password should have at least 8 characters") + + passwordEncrypterMock. + On("Encrypt", ctx, password). + Return("", errUnexpectedError). + Once() + + u, err = authenticator.CreateUser(ctx, user, password) + + assert.Nil(t, u) + assert.EqualError(t, err, "error encrypting password: unexpected error") + }) + + t.Run("returns error when user is duplicated", func(t *testing.T) { + user := &User{ + Email: "email@email.com", + FirstName: "First", + LastName: "Last", + } + + password := "mysecret" + + passwordEncrypterMock. + On("Encrypt", ctx, password). + Return("encrypted", nil). + Twice() + + _, err := authenticator.CreateUser(ctx, user, password) + require.NoError(t, err) + + u, err := authenticator.CreateUser(ctx, user, password) + + assert.Nil(t, u) + assert.EqualError(t, err, ErrUserEmailAlreadyExists.Error()) + }) + + t.Run("creates a new user correctly", func(t *testing.T) { + user := &User{ + Email: "email-test@email.com", + FirstName: "First", + LastName: "Last", + } + + password := "mysecret" + + passwordEncrypterMock. + On("Encrypt", ctx, password). + Return("encryptedpassword", nil). + Once() + + u, err := authenticator.CreateUser(ctx, user, password) + require.NoError(t, err) + + const query = "SELECT id, email, first_name, last_name, encrypted_password, is_active FROM auth_users WHERE email = $1" + + var newUser User + var encryptedPassword string + err = dbConnectionPool.QueryRowxContext(ctx, query, user.Email).Scan(&newUser.ID, &newUser.Email, &newUser.FirstName, &newUser.LastName, &encryptedPassword, &newUser.IsActive) + require.NoError(t, err) + + assert.Equal(t, newUser.ID, u.ID) + assert.Equal(t, newUser.Email, u.Email) + assert.Equal(t, newUser.FirstName, u.FirstName) + assert.Equal(t, newUser.LastName, u.LastName) + assert.Equal(t, newUser.IsActive, u.IsActive) + assert.Equal(t, "encryptedpassword", encryptedPassword) + }) + + t.Run("creates a user successfully with an OTP", func(t *testing.T) { + user := &User{ + Email: "emailotp@email.com", + FirstName: "First", + LastName: "Last", + } + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encryptedpassword", nil). + Once() + + u, err := authenticator.CreateUser(ctx, user, "") + require.NoError(t, err) + + const query = "SELECT id, email, first_name, last_name, encrypted_password FROM auth_users WHERE email = $1" + + var newUser User + var encryptedPassword string + err = dbConnectionPool.QueryRowxContext(ctx, query, user.Email).Scan(&newUser.ID, &newUser.Email, &newUser.FirstName, &newUser.LastName, &encryptedPassword) + require.NoError(t, err) + + assert.Equal(t, newUser.ID, u.ID) + assert.Equal(t, newUser.Email, u.Email) + assert.Equal(t, newUser.FirstName, u.FirstName) + assert.Equal(t, newUser.LastName, u.LastName) + assert.Equal(t, "encryptedpassword", encryptedPassword) + }) + + passwordEncrypterMock.AssertExpectations(t) +} + +func Test_DefaultAuthenticator_ActivateUser(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool)) + + ctx := context.Background() + + t.Run("returns error when user does not exist", func(t *testing.T) { + err = authenticator.ActivateUser(ctx, "user-id") + assert.EqualError(t, err, "error activating user ID user-id: no rows affected") + }) + + t.Run("activate user correctly", func(t *testing.T) { + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + err := authenticator.updateIsActive(ctx, randUser.ID, false) + require.NoError(t, err) + assertUserIsActive(t, ctx, dbConnectionPool, randUser.ID, false) + + err = authenticator.ActivateUser(ctx, randUser.ID) + require.NoError(t, err) + assertUserIsActive(t, ctx, dbConnectionPool, randUser.ID, true) + }) +} + +func Test_DefaultAuthenticator_DeactivateUser(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool)) + + ctx := context.Background() + + t.Run("returns error when user does not exist", func(t *testing.T) { + err = authenticator.DeactivateUser(ctx, "user-id") + assert.EqualError(t, err, "error deactivating user ID user-id: no rows affected") + }) + + t.Run("deactivate user correctly", func(t *testing.T) { + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + assertUserIsActive(t, ctx, dbConnectionPool, randUser.ID, true) + + err = authenticator.DeactivateUser(ctx, randUser.ID) + require.NoError(t, err) + assertUserIsActive(t, ctx, dbConnectionPool, randUser.ID, false) + }) +} + +func Test_DefaultAuthenticator_invalidateResetPasswordToken(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool), withPasswordEncrypter(passwordEncrypterMock)) + + ctx := context.Background() + + t.Run("Should change status of the token to invalid", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + token := CreateResetPasswordTokenFixture(t, ctx, dbConnectionPool, randUser, true, time.Now()) + + dbTx, err := dbConnectionPool.BeginTxx(ctx, nil) + require.NoError(t, err) + + err = authenticator.invalidateResetPasswordToken(ctx, dbTx, token) + require.NoError(t, err) + + err = dbTx.Commit() + require.NoError(t, err) + + var dbToken string + q := "SELECT token FROM auth_user_password_reset WHERE token = $1 AND is_valid = true" + err = dbConnectionPool.GetContext(ctx, &dbToken, q, token) + require.EqualError(t, err, sql.ErrNoRows.Error()) + require.Empty(t, dbToken) + }) + + passwordEncrypterMock.AssertExpectations(t) +} + +func Test_DefaultAuthenticator_ResetPassword(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool), withPasswordEncrypter(passwordEncrypterMock)) + + ctx := context.Background() + + t.Run("Should treat encrypt password error", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + newPassword := "new_not_encrypted_pass" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once(). + On("Encrypt", ctx, newPassword). + Return("", errUnexpectedError). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + token := CreateResetPasswordTokenFixture(t, ctx, dbConnectionPool, randUser, true, time.Now()) + + err := authenticator.ResetPassword(ctx, token, newPassword) + assert.EqualError(t, err, "running atomic function in RunInTransactionWithResult: error trying to encrypt user password: unexpected error") + }) + + t.Run("Should treat a not found token error", func(t *testing.T) { + err := authenticator.ResetPassword(ctx, "notfoundtoken", "newpassword") + assert.EqualError(t, err, "running atomic function in RunInTransactionWithResult: "+ErrInvalidResetPasswordToken.Error()) + }) + + t.Run("Should reset the password with a valid token, and make the token invalid after", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + newPassword := "new_not_encrypted_pass" + newEncryptedPassword := "newencryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once(). + On("Encrypt", ctx, newPassword). + Return(newEncryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + token := CreateResetPasswordTokenFixture(t, ctx, dbConnectionPool, randUser, true, time.Now()) + + err := authenticator.ResetPassword(ctx, token, newPassword) + require.NoError(t, err) + + // Token should be invalid after + var dbIsValid bool + q := `SELECT is_valid FROM auth_user_password_reset WHERE token = $1` + err = dbConnectionPool.GetContext(ctx, &dbIsValid, q, token) + require.NoError(t, err) + assert.False(t, dbIsValid) + + // User should have a new password encrypted + var expectedNewEncryptedPass string + q = `SELECT encrypted_password FROM auth_users WHERE id = $1` + err = dbConnectionPool.GetContext(ctx, &expectedNewEncryptedPass, q, randUser.ID) + require.NoError(t, err) + assert.Equal(t, expectedNewEncryptedPass, newEncryptedPassword) + }) + + t.Run("Should return an error with an expired token", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + newPassword := "new_not_encrypted_pass" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + token := CreateResetPasswordTokenFixture(t, ctx, dbConnectionPool, randUser, true, time.Now().Add(-time.Hour*25)) + + err := authenticator.ResetPassword(ctx, token, newPassword) + require.EqualError(t, err, "running atomic function in RunInTransactionWithResult: "+ErrInvalidResetPasswordToken.Error()) + }) + + passwordEncrypterMock.AssertExpectations(t) +} + +func Test_DefaultAuthenticator_ForgotPassword(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool), withPasswordEncrypter(passwordEncrypterMock)) + + ctx := context.Background() + + t.Run("Should return an error if the email is empty", func(t *testing.T) { + resetToken, err := authenticator.ForgotPassword(ctx, "") + assert.EqualError(t, err, "error generating user reset password token: email cannot be empty") + assert.Empty(t, resetToken) + }) + + t.Run("Should return an error if the user is not found", func(t *testing.T) { + resetToken, err := authenticator.ForgotPassword(ctx, "notfounduser@email.com") + assert.EqualError(t, err, ErrUserNotFound.Error()) + assert.Empty(t, resetToken) + }) + + t.Run("should return an error if user has valid token", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + + resetToken, err := authenticator.ForgotPassword(ctx, randUser.Email) + require.NoError(t, err) + assert.NotEmpty(t, resetToken) + + resetTokenFail1, err := authenticator.ForgotPassword(ctx, randUser.Email) + require.EqualError(t, err, "user has a valid token") + assert.Empty(t, resetTokenFail1) + + updateTokenQuery := ` + UPDATE auth_user_password_reset + SET created_at = (created_at - INTERVAL '19 minutes') + WHERE token = $1 + ` + _, err = dbConnectionPool.ExecContext(ctx, updateTokenQuery, resetToken) + require.NoError(t, err) + + resetTokenFail2, err := authenticator.ForgotPassword(ctx, randUser.Email) + require.EqualError(t, err, "user has a valid token") + assert.Empty(t, resetTokenFail2) + }) + + t.Run("should return reset token when previous token is expired", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + + oldResetToken, err := authenticator.ForgotPassword(ctx, randUser.Email) + require.NoError(t, err) + assert.NotEmpty(t, oldResetToken) + + // Expire old token + updateTokenQuery := ` + UPDATE auth_user_password_reset + SET created_at = (created_at - INTERVAL '20 minutes') + WHERE token = $1 + ` + _, err = dbConnectionPool.ExecContext(ctx, updateTokenQuery, oldResetToken) + require.NoError(t, err) + + newResetToken, err := authenticator.ForgotPassword(ctx, randUser.Email) + require.NoError(t, err) + assert.NotEmpty(t, newResetToken) + assert.NotEqual(t, oldResetToken, newResetToken) + }) + + t.Run("Should return reset token with a valid user", func(t *testing.T) { + encryptedPassword := "encryptedpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return(encryptedPassword, nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + + resetToken, err := authenticator.ForgotPassword(ctx, randUser.Email) + require.NoError(t, err) + + assert.NotEmpty(t, resetToken) + }) + + passwordEncrypterMock.AssertExpectations(t) +} + +func Test_withResetTokenExpirationHours(t *testing.T) { + authenticator := newDefaultAuthenticator(withResetTokenExpirationHours(time.Hour * 24)) + assert.Equal(t, time.Hour*24, authenticator.resetTokenExpirationHours) + + authenticator = newDefaultAuthenticator(withResetTokenExpirationHours(time.Minute * 30)) + assert.Equal(t, time.Minute*30, authenticator.resetTokenExpirationHours) +} + +func Test_DefaultAuthenticator_GetAllUsers(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool)) + + ctx := context.Background() + + t.Run("returns an empty array if no users are registered", func(t *testing.T) { + users, err := authenticator.GetAllUsers(ctx) + require.NoError(t, err) + + assert.Empty(t, users) + }) + + t.Run("gets all users successfully", func(t *testing.T) { + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encryptedPassword", nil) + + randUser1 := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false, "role1", "role2") + randUser2 := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, true, "role1", "role2") + randUser3 := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false, "role3") + + users, err := authenticator.GetAllUsers(ctx) + require.NoError(t, err) + + expectedUsers := []User{ + *randUser1.ToUser(), + *randUser2.ToUser(), + *randUser3.ToUser(), + } + + assert.Equal(t, expectedUsers, users) + }) + + passwordEncrypterMock.AssertExpectations(t) +} + +func Test_DefaultAuthenticator_UpdateUser(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator( + withAuthenticatorDatabaseConnectionPool(dbConnectionPool), + withPasswordEncrypter(passwordEncrypterMock), + ) + + ctx := context.Background() + + type dbUser struct { + ID string `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string `db:"email"` + EncryptedPassword string `db:"encrypted_password"` + } + + getUser := func(t *testing.T, ctx context.Context, ID string) *dbUser { + const query = ` + SELECT id, first_name, last_name, email, encrypted_password FROM auth_users WHERE id = $1 + ` + var u dbUser + err := dbConnectionPool.GetContext(ctx, &u, query, ID) + require.NoError(t, err) + + return &u + } + + t.Run("returns error when no value is provided", func(t *testing.T) { + err := authenticator.UpdateUser(ctx, "user-id", "", "", "", "") + assert.EqualError(t, err, "provide at least one of these values: firstName, lastName, email or password") + }) + + t.Run("returns error when email is invalid", func(t *testing.T) { + err := authenticator.UpdateUser(ctx, "user-id", "", "", "invalid", "") + assert.EqualError(t, err, `error validating email: the provided email "invalid" is not valid`) + }) + + t.Run("returns error when password is too short", func(t *testing.T) { + password := "short" + + passwordEncrypterMock. + On("Encrypt", ctx, password). + Return("", ErrPasswordTooShort). + Once() + + err := authenticator.UpdateUser(ctx, "user-id", "", "", "", "short") + assert.EqualError(t, err, "password should have at least 8 characters") + }) + + t.Run("returns error when PasswordEncrypter fails", func(t *testing.T) { + password := "short" + + passwordEncrypterMock. + On("Encrypt", ctx, password). + Return("", errUnexpectedError). + Once() + + err := authenticator.UpdateUser(ctx, "user-id", "", "", "", "short") + assert.EqualError(t, err, "error encrypting password: unexpected error") + }) + + t.Run("updates first name successfully", func(t *testing.T) { + firstName := "FirstName" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encrypted", nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + assert.NotEqual(t, firstName, randUser.FirstName) + + err := authenticator.UpdateUser(ctx, randUser.ID, firstName, "", "", "") + require.NoError(t, err) + + u := getUser(t, ctx, randUser.ID) + + assert.Equal(t, firstName, u.FirstName) + assert.Equal(t, randUser.LastName, u.LastName) + assert.Equal(t, randUser.Email, u.Email) + assert.Equal(t, randUser.EncryptedPassword, u.EncryptedPassword) + }) + + t.Run("updates last name successfully", func(t *testing.T) { + lastName := "LastName" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encrypted", nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + assert.NotEqual(t, lastName, randUser.LastName) + + err := authenticator.UpdateUser(ctx, randUser.ID, "", lastName, "", "") + require.NoError(t, err) + + u := getUser(t, ctx, randUser.ID) + + assert.Equal(t, lastName, u.LastName) + assert.Equal(t, randUser.FirstName, u.FirstName) + assert.Equal(t, randUser.Email, u.Email) + assert.Equal(t, randUser.EncryptedPassword, u.EncryptedPassword) + }) + + t.Run("updates email successfully", func(t *testing.T) { + email := "email@email.com" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encrypted", nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + assert.NotEqual(t, email, randUser.Email) + + err := authenticator.UpdateUser(ctx, randUser.ID, "", "", email, "") + require.NoError(t, err) + + u := getUser(t, ctx, randUser.ID) + + assert.Equal(t, email, u.Email) + assert.Equal(t, randUser.FirstName, u.FirstName) + assert.Equal(t, randUser.LastName, u.LastName) + assert.Equal(t, randUser.EncryptedPassword, u.EncryptedPassword) + }) + + t.Run("updates password successfully", func(t *testing.T) { + password := "newpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encrypted", nil). + Once(). + On("Encrypt", ctx, password). + Return("newpassowrdencrypted", nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + assert.NotEqual(t, "newpassowrdencrypted", randUser.EncryptedPassword) + + err := authenticator.UpdateUser(ctx, randUser.ID, "", "", "", password) + require.NoError(t, err) + + u := getUser(t, ctx, randUser.ID) + + assert.Equal(t, "newpassowrdencrypted", u.EncryptedPassword) + assert.Equal(t, randUser.FirstName, u.FirstName) + assert.Equal(t, randUser.LastName, u.LastName) + assert.Equal(t, randUser.Email, u.Email) + }) + + t.Run("updates all fields successfully", func(t *testing.T) { + firstName, lastName, email, password := "FirstName", "LastName", "new_email@email.com", "newpassword" + + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encrypted", nil). + Once(). + On("Encrypt", ctx, password). + Return("newpassowrdencrypted", nil). + Once() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false) + assert.NotEqual(t, firstName, randUser.FirstName) + assert.NotEqual(t, lastName, randUser.LastName) + assert.NotEqual(t, email, randUser.Email) + assert.NotEqual(t, "newpassowrdencrypted", randUser.EncryptedPassword) + + err := authenticator.UpdateUser(ctx, randUser.ID, firstName, lastName, email, password) + require.NoError(t, err) + + u := getUser(t, ctx, randUser.ID) + + assert.Equal(t, firstName, u.FirstName) + assert.Equal(t, lastName, u.LastName) + assert.Equal(t, email, u.Email) + assert.Equal(t, "newpassowrdencrypted", u.EncryptedPassword) + }) +} + +func Test_DefaultAuthenticator_GetUser(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + passwordEncrypterMock := &PasswordEncrypterMock{} + authenticator := newDefaultAuthenticator(withAuthenticatorDatabaseConnectionPool(dbConnectionPool)) + + ctx := context.Background() + + t.Run("returns error when user is not found", func(t *testing.T) { + user, err := authenticator.GetUser(ctx, "user-id") + assert.ErrorIs(t, err, ErrUserNotFound) + assert.Nil(t, user) + }) + + t.Run("returns user successfully", func(t *testing.T) { + passwordEncrypterMock. + On("Encrypt", ctx, mock.AnythingOfType("string")). + Return("encryptedPassword", nil) + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, passwordEncrypterMock, false, "role1") + + u, err := authenticator.GetUser(ctx, randUser.ID) + require.NoError(t, err) + + assert.Equal(t, randUser.ID, u.ID) + assert.Equal(t, randUser.FirstName, u.FirstName) + assert.Equal(t, randUser.LastName, u.LastName) + assert.Equal(t, randUser.Email, u.Email) + }) +} diff --git a/stellar-auth/pkg/auth/fixtures.go b/stellar-auth/pkg/auth/fixtures.go new file mode 100644 index 000000000..033e7a329 --- /dev/null +++ b/stellar-auth/pkg/auth/fixtures.go @@ -0,0 +1,104 @@ +package auth + +import ( + "context" + "crypto/rand" + "fmt" + "math/big" + "testing" + "time" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/utils" + "github.com/stretchr/testify/require" +) + +var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +type RandomAuthUser struct { + ID string + Email string + FirstName string + LastName string + Password string + EncryptedPassword string + IsOwner bool + IsActive bool + Roles []string + CreatedAt time.Time +} + +func (rau *RandomAuthUser) ToUser() *User { + return &User{ + ID: rau.ID, + FirstName: rau.FirstName, + LastName: rau.LastName, + Email: rau.Email, + IsOwner: rau.IsOwner, + IsActive: rau.IsActive, + Roles: rau.Roles, + } +} + +func randStringRunes(t *testing.T, n int) string { + b := make([]rune, n) + for i := range b { + randomNumber, err := rand.Int(rand.Reader, big.NewInt(int64(len(letterRunes)))) + require.NoError(t, err) + + b[i] = letterRunes[randomNumber.Int64()] + } + return string(b) +} + +func CreateRandomAuthUserFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, passwordEncrypter PasswordEncrypter, isAdmin bool, roles ...string) *RandomAuthUser { + randomSuffix := randStringRunes(t, 5) + email := fmt.Sprintf("email%s@randomemail.com", randomSuffix) + password := "password" + randomSuffix + firstName := "firstName" + randomSuffix + lastName := "lastName" + randomSuffix + + encryptedPassword, err := passwordEncrypter.Encrypt(ctx, password) + require.NoError(t, err) + + const query = ` + INSERT INTO auth_users + (email, encrypted_password, is_owner, roles, first_name, last_name) + VALUES + ($1, $2, $3, $4, $5, $6) + RETURNING + id, created_at + ` + + user := &RandomAuthUser{ + Email: email, + FirstName: firstName, + LastName: lastName, + Password: password, + IsOwner: isAdmin, + IsActive: true, + EncryptedPassword: encryptedPassword, + Roles: roles, + } + err = sqlExec.QueryRowxContext(ctx, query, email, encryptedPassword, isAdmin, pq.Array(roles), firstName, lastName).Scan(&user.ID, &user.CreatedAt) + require.NoError(t, err) + + return user +} + +func CreateResetPasswordTokenFixture(t *testing.T, ctx context.Context, sqlExec db.SQLExecuter, randomAuthUser *RandomAuthUser, isValid bool, createdAt time.Time) (token string) { + resetToken, err := utils.StringWithCharset(resetTokenLength, utils.DefaultCharset) + require.NoError(t, err) + + q := ` + INSERT INTO + auth_user_password_reset (token, auth_user_id, is_valid, created_at) + VALUES + ($1, $2, $3, $4) + ` + _, err = sqlExec.ExecContext(ctx, q, resetToken, randomAuthUser.ID, isValid, createdAt) + require.NoError(t, err) + + return resetToken +} diff --git a/stellar-auth/pkg/auth/jwt_manager.go b/stellar-auth/pkg/auth/jwt_manager.go new file mode 100644 index 000000000..ff78cb3b0 --- /dev/null +++ b/stellar-auth/pkg/auth/jwt_manager.go @@ -0,0 +1,142 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "time" + + jwtgo "github.com/golang-jwt/jwt/v4" +) + +const defaultRefreshTimeout = 30 + +type JWTManager interface { + GenerateToken(ctx context.Context, user *User, expiresAt time.Time) (string, error) + RefreshToken(ctx context.Context, token string, expiresAt time.Time) (string, error) + ValidateToken(ctx context.Context, token string) (bool, error) + GetUserFromToken(ctx context.Context, token string) (*User, error) +} + +type claims struct { + User *User `json:"user"` + jwtgo.RegisteredClaims +} + +// defaultJWTManager +type defaultJWTManager struct { + privateKey string + publicKey string +} + +func (m *defaultJWTManager) parseToken(tokenString string) (*jwtgo.Token, *claims, error) { + c := &claims{} + token, err := jwtgo.ParseWithClaims(tokenString, c, func(t *jwtgo.Token) (interface{}, error) { + esPublicKey, err := jwtgo.ParseECPublicKeyFromPEM([]byte(m.publicKey)) + if err != nil { + return nil, fmt.Errorf("parsing EC Public Key: %w", err) + } + + return esPublicKey, nil + }) + if err != nil { + vErr, ok := err.(*jwtgo.ValidationError) + if !ok { + return nil, nil, fmt.Errorf("parsing token: %w", err) + } + + if vErr.Errors == jwtgo.ValidationErrorUnverifiable { + return nil, nil, fmt.Errorf("invalid key: %w", err) + } + + return nil, nil, ErrInvalidToken + } + + return token, c, nil +} + +func (m *defaultJWTManager) GenerateToken(ctx context.Context, user *User, expiresAt time.Time) (string, error) { + esPrivateKey, err := jwtgo.ParseECPrivateKeyFromPEM([]byte(m.privateKey)) + if err != nil { + return "", fmt.Errorf("parsing EC Private Key: %w", err) + } + + c := &claims{ + User: user, + RegisteredClaims: jwtgo.RegisteredClaims{ + ExpiresAt: jwtgo.NewNumericDate(expiresAt), + }, + } + + token := jwtgo.NewWithClaims(jwtgo.SigningMethodES256, c) + + tokenString, err := token.SignedString(esPrivateKey) + if err != nil { + return "", fmt.Errorf("signing token: %w", err) + } + + return tokenString, nil +} + +func (m *defaultJWTManager) RefreshToken(ctx context.Context, tokenString string, expiresAt time.Time) (string, error) { + _, c, err := m.parseToken(tokenString) + if err != nil { + return "", fmt.Errorf("parsing token to be refreshed: %w", err) + } + + // We only generate new tokens when enough time + // is elapsed. + if time.Until(c.ExpiresAt.Time) > defaultRefreshTimeout*time.Second { + return tokenString, nil + } + + tokenString, err = m.GenerateToken(ctx, c.User, expiresAt) + if err != nil { + return "", fmt.Errorf("generating new refreshed token: %w", err) + } + + return tokenString, nil +} + +func (m *defaultJWTManager) ValidateToken(ctx context.Context, tokenString string) (bool, error) { + token, _, err := m.parseToken(tokenString) + if errors.Is(err, ErrInvalidToken) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("parsing token to be validated: %w", err) + } + + return token.Valid, nil +} + +func (m *defaultJWTManager) GetUserFromToken(ctx context.Context, tokenString string) (*User, error) { + _, c, err := m.parseToken(tokenString) + if err != nil { + return nil, fmt.Errorf("parsing token to be validated: %w", err) + } + + return c.User, nil +} + +type defaultJWTManagerOption func(m *defaultJWTManager) + +func newDefaultJWTManager(options ...defaultJWTManagerOption) *defaultJWTManager { + jwtManager := &defaultJWTManager{} + + for _, option := range options { + option(jwtManager) + } + + return jwtManager +} + +func withECKeypair(publicKey string, privateKey string) defaultJWTManagerOption { + return func(m *defaultJWTManager) { + m.publicKey = publicKey + m.privateKey = privateKey + } +} + +// Ensuring that defaultJWTManager is implementing JWTManager interface +var _ JWTManager = (*defaultJWTManager)(nil) diff --git a/stellar-auth/pkg/auth/jwt_manager_test.go b/stellar-auth/pkg/auth/jwt_manager_test.go new file mode 100644 index 000000000..1a9d5fd4c --- /dev/null +++ b/stellar-auth/pkg/auth/jwt_manager_test.go @@ -0,0 +1,183 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// NEVER use these values in production! +var ( + testPrivateKey = `-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgaWqFzmxoHbYUbZEm +EO5XNy9QX3cTAh2jtEi+lOJsnEihRANCAAQ0VOBzsDLy4rqNM5G/Go6IBrRIV7Er +Aftohtbum9ABi8CEq05EzjTGf/D8pzW5RXOhgQhm3jGVv4/fzAtTtunR +-----END PRIVATE KEY-----` + testPublicKey = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAENFTgc7Ay8uK6jTORvxqOiAa0SFex +KwH7aIbW7pvQAYvAhKtORM40xn/w/Kc1uUVzoYEIZt4xlb+P38wLU7bp0Q== +-----END PUBLIC KEY-----` +) + +func Test_DefaultJWTManager_GenerateToken(t *testing.T) { + ctx := context.Background() + + t.Run("returns error when the EC Private Key is invalid", func(t *testing.T) { + jwtManager := newDefaultJWTManager(withECKeypair(testPublicKey, "invalid")) + + expiresAt := time.Now().Add(time.Minute * 5) + token, err := jwtManager.GenerateToken(ctx, &User{}, expiresAt) + + assert.EqualError(t, err, "parsing EC Private Key: invalid key: Key must be a PEM encoded PKCS1 or PKCS8 key") + assert.Empty(t, token) + }) + + t.Run("generates token correctly", func(t *testing.T) { + jwtManager := newDefaultJWTManager(withECKeypair(testPublicKey, testPrivateKey)) + + expiresAt := time.Now().Add(time.Minute * 5) + token, err := jwtManager.GenerateToken(ctx, &User{}, expiresAt) + require.NoError(t, err) + + assert.NotEmpty(t, token) + }) +} + +func Test_DefaultJWTManager_ValidateToken(t *testing.T) { + jwtManager := newDefaultJWTManager(withECKeypair(testPublicKey, testPrivateKey)) + + ctx := context.Background() + + t.Run("returns false when token has a invalid signature", func(t *testing.T) { + invalidSignatureToken := "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjp7ImlkIjoidXNlci1pZCIsImVtYWlsIjoiZW1haWxAZW1haWwuY29tIiwicm9sZXMiOlt7Im5hbWUiOiJTdXBlcnZpc29yIn1dfSwiZXhwIjoxNjc1OTYyOTQ3fQ.zK9Jb5EMl5rOTOO18SM-q_WOtD0TbL0f9cFfilW9tWHa_vjVMEaf6xRjold9dTPLICDBrqdw_luhKlT370EAiA" + + isValid, err := jwtManager.ValidateToken(ctx, invalidSignatureToken) + require.NoError(t, err) + + assert.False(t, isValid) + }) + + t.Run("returns false when token is expired", func(t *testing.T) { + expiresAt := time.Now().Add(time.Minute * -5) + token, err := jwtManager.GenerateToken(ctx, &User{}, expiresAt) + require.NoError(t, err) + + isValid, err := jwtManager.ValidateToken(ctx, token) + require.NoError(t, err) + + assert.False(t, isValid) + }) + + t.Run("returns false when token has invalid segments", func(t *testing.T) { + isValid, err := jwtManager.ValidateToken(ctx, "token") + require.NoError(t, err) + + assert.False(t, isValid) + }) + + t.Run("returns true when token is valid", func(t *testing.T) { + expiresAt := time.Now().Add(time.Minute * 5) + token, err := jwtManager.GenerateToken(ctx, &User{}, expiresAt) + require.NoError(t, err) + + isValid, err := jwtManager.ValidateToken(ctx, token) + require.NoError(t, err) + + assert.True(t, isValid) + }) +} + +func Test_DefaultJWTManager_RefreshToken(t *testing.T) { + jwtManager := newDefaultJWTManager(withECKeypair(testPublicKey, testPrivateKey)) + + ctx := context.Background() + + t.Run("returns the same token when is above 30 secs until expires", func(t *testing.T) { + expiresAt := time.Now().Add(time.Second * 31) + token, err := jwtManager.GenerateToken(ctx, &User{}, expiresAt) + require.NoError(t, err) + + newExpiresAt := time.Now().Add(time.Minute * 5) + refreshedToken, err := jwtManager.RefreshToken(ctx, token, newExpiresAt) + require.NoError(t, err) + + assert.Equal(t, token, refreshedToken) + }) + + t.Run("returns a refreshed token", func(t *testing.T) { + expiresAt := time.Now().Add(time.Second * 30) + token, err := jwtManager.GenerateToken(ctx, &User{}, expiresAt) + require.NoError(t, err) + + newExpiresAt := time.Now().Add(time.Minute * 5) + refreshedToken, err := jwtManager.RefreshToken(ctx, token, newExpiresAt) + require.NoError(t, err) + + assert.NotEqual(t, token, refreshedToken) + }) +} + +func Test_DefaultJWTManager_parseToken(t *testing.T) { + ctx := context.Background() + + t.Run("returns error when the EC Public Key is invalid", func(t *testing.T) { + jwtManager := newDefaultJWTManager(withECKeypair("invalid", testPrivateKey)) + + expiresAt := time.Now().Add(time.Minute * 5) + tokenString, err := jwtManager.GenerateToken(ctx, &User{}, expiresAt) + require.NoError(t, err) + + token, c, err := jwtManager.parseToken(tokenString) + + assert.EqualError(t, err, "invalid key: parsing EC Public Key: invalid key: Key must be a PEM encoded PKCS1 or PKCS8 key") + assert.Nil(t, token) + assert.Nil(t, c) + }) + + t.Run("returns token and claims correctly", func(t *testing.T) { + jwtManager := newDefaultJWTManager(withECKeypair(testPublicKey, testPrivateKey)) + + expectedUser := &User{ + ID: "user-ID", + Email: "email@email.com", + Roles: []string{ + "role1", + }, + } + + expiresAt := time.Now().Add(time.Minute * 5).Truncate(time.Second) + tokenString, err := jwtManager.GenerateToken(ctx, expectedUser, expiresAt) + require.NoError(t, err) + + token, c, err := jwtManager.parseToken(tokenString) + require.NoError(t, err) + + assert.Equal(t, expectedUser, c.User) + assert.Equal(t, expiresAt, c.ExpiresAt.Time) + assert.Equal(t, tokenString, token.Raw) + }) +} + +func Test_DefaultJWTManager_GetUserFromToken(t *testing.T) { + ctx := context.Background() + + jwtManager := newDefaultJWTManager(withECKeypair(testPublicKey, testPrivateKey)) + + expectedUser := &User{ + ID: "user-id", + Email: "email@email.com", + Roles: []string{"role1", "role2"}, + } + + expiresAt := time.Now().Add(time.Minute * 5).Truncate(time.Second) + token, err := jwtManager.GenerateToken(ctx, expectedUser, expiresAt) + require.NoError(t, err) + + gotUser, err := jwtManager.GetUserFromToken(ctx, token) + require.NoError(t, err) + + assert.Equal(t, expectedUser, gotUser) +} diff --git a/stellar-auth/pkg/auth/manager.go b/stellar-auth/pkg/auth/manager.go new file mode 100644 index 000000000..0db24c9d2 --- /dev/null +++ b/stellar-auth/pkg/auth/manager.go @@ -0,0 +1,135 @@ +package auth + +import ( + "fmt" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/utils" +) + +const defaultExpirationTimeInMinutes = 15 + +type User struct { + ID string `json:"id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + IsOwner bool `json:"-"` + IsActive bool `json:"is_active"` + Roles []string `json:"roles"` +} + +func (u *User) Validate() error { + if u.Email == "" { + return fmt.Errorf("email is required") + } else if err := utils.ValidateEmail(u.Email); err != nil { + return fmt.Errorf("email is invalid: %w", err) + } + + if u.FirstName == "" { + return fmt.Errorf("first name is required") + } + + if u.LastName == "" { + return fmt.Errorf("last name is required") + } + + return nil +} + +// AuthManager manages the JWT token generation, validation and refresh. Use `NewAuthManager` function +// to construct a new pointer. +type defaultAuthManager struct { + expirationTimeInMinutes time.Duration + authenticator Authenticator + jwtManager JWTManager + roleManager RoleManager + mfaManager MFAManager +} + +type AuthManagerOption func(am *defaultAuthManager) + +// NewAuthManager constructs a new `*AuthManager` and apply the options passed by parameter. +func NewAuthManager(options ...AuthManagerOption) AuthManager { + authManager := &defaultAuthManager{ + expirationTimeInMinutes: time.Minute * defaultExpirationTimeInMinutes, + } + + for _, option := range options { + option(authManager) + } + + return authManager +} + +// WithDefaultAuthenticatorOption sets a default authentication method that validates the users' credentials. +func WithDefaultAuthenticatorOption(dbConnectionPool db.DBConnectionPool, passwordEncrypter PasswordEncrypter, resetTokenExpirationHours time.Duration) AuthManagerOption { + return func(am *defaultAuthManager) { + am.authenticator = newDefaultAuthenticator( + withAuthenticatorDatabaseConnectionPool(dbConnectionPool), + withPasswordEncrypter(passwordEncrypter), + withResetTokenExpirationHours(resetTokenExpirationHours), + ) + } +} + +// WithDefaultAuthenticatorOption sets a custom authentication method that implements the `Authenticator` interface. +func WithCustomAuthenticatorOption(authenticator Authenticator) AuthManagerOption { + return func(am *defaultAuthManager) { + am.authenticator = authenticator + } +} + +// WithDefaultJWTManagerOption sets a default JWT Manager that generates, validates and refreshes the users' JWT token. +func WithDefaultJWTManagerOption(ECPublicKey, ECPrivateKey string) AuthManagerOption { + return func(am *defaultAuthManager) { + am.jwtManager = newDefaultJWTManager(withECKeypair(ECPublicKey, ECPrivateKey)) + } +} + +// WithDefaultJWTManagerOption sets a custom JWT Manager that implements the `JWTManager` interface. +func WithCustomJWTManagerOption(jwtManager JWTManager) AuthManagerOption { + return func(am *defaultAuthManager) { + am.jwtManager = jwtManager + } +} + +// WithExpirationTimeInMinutesOption sets the JWT token expiration time in minutes. Default is `5 minutes`. +func WithExpirationTimeInMinutesOption(minutes int) AuthManagerOption { + return func(am *defaultAuthManager) { + am.expirationTimeInMinutes = time.Minute * time.Duration(minutes) + } +} + +func WithDefaultRoleManagerOption(dbConnectionPool db.DBConnectionPool, ownerRoleName string) AuthManagerOption { + return func(am *defaultAuthManager) { + roleOptions := []defaultRoleManagerOption{ + withRoleManagerDBConnectionPool(dbConnectionPool), + } + + if ownerRoleName != "" { + roleOptions = append(roleOptions, withOwnerRoleName(ownerRoleName)) + } + + am.roleManager = newDefaultRoleManager(roleOptions...) + } +} + +func WithCustomRoleManagerOption(roleManager RoleManager) AuthManagerOption { + return func(am *defaultAuthManager) { + am.roleManager = roleManager + } +} + +func WithDefaultMFAManagerOption(dbConnectionPool db.DBConnectionPool) AuthManagerOption { + return func(am *defaultAuthManager) { + am.mfaManager = newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + } +} + +func WithCustomMFAManagerOption(mfaManager MFAManager) AuthManagerOption { + return func(am *defaultAuthManager) { + am.mfaManager = mfaManager + } +} diff --git a/stellar-auth/pkg/auth/manager_test.go b/stellar-auth/pkg/auth/manager_test.go new file mode 100644 index 000000000..7fad91a98 --- /dev/null +++ b/stellar-auth/pkg/auth/manager_test.go @@ -0,0 +1,32 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_User_Validate(t *testing.T) { + user := &User{ + ID: "", + FirstName: "", + LastName: "", + Email: "", + IsOwner: false, + Roles: []string{}, + } + + assert.EqualError(t, user.Validate(), "email is required") + + user.Email = "invalid" + assert.EqualError(t, user.Validate(), `email is invalid: the provided email "invalid" is not valid`) + + user.Email = "email@email.com" + assert.EqualError(t, user.Validate(), "first name is required") + + user.FirstName = "First" + assert.EqualError(t, user.Validate(), "last name is required") + + user.LastName = "Last" + assert.NoError(t, user.Validate()) +} diff --git a/stellar-auth/pkg/auth/mfa_manager.go b/stellar-auth/pkg/auth/mfa_manager.go new file mode 100644 index 000000000..5593b10ef --- /dev/null +++ b/stellar-auth/pkg/auth/mfa_manager.go @@ -0,0 +1,307 @@ +package auth + +import ( + "context" + "crypto/rand" + "database/sql" + "errors" + "fmt" + "math/big" + "time" + + "github.com/stellar/go/support/log" + + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" +) + +type MFAManager interface { + MFADeviceRemembered(ctx context.Context, deviceID, userID string) (bool, error) + GenerateMFACode(ctx context.Context, deviceID, userID string) (string, error) + ValidateMFACode(ctx context.Context, deviceID, code string) (string, error) + RememberDevice(ctx context.Context, deviceID, code string) error +} + +// defaultMFAManager +type defaultMFAManager struct { + dbConnectionPool db.DBConnectionPool +} + +const ( + mfaCodeMaxLength = 6 + mfaDeviceExpiryHours = time.Hour * 24 * 7 // 7 days + mfaCodeExpiryMinutes = time.Minute * 5 // 5 minutes +) + +var ( + ErrMFACodeInvalid = errors.New("MFA code is invalid") + ErrMFANoCodeForUserDevice = errors.New("no MFA code for user and device") +) + +type mfaCode struct { + DeviceID string `db:"device_id"` + UserID string `db:"auth_user_id"` + Code string `db:"code"` + DeviceExpiresAt *time.Time `db:"device_expires_at"` + CodeExpiresAt *time.Time `db:"code_expires_at"` +} + +// MFADeviceRemembered checks if the device is remembered for the user. +func (m *defaultMFAManager) MFADeviceRemembered(ctx context.Context, deviceID, userID string) (bool, error) { + mc, err := m.getByDeviceAndUser(ctx, deviceID, userID) + if err != nil { + if errors.Is(err, ErrMFANoCodeForUserDevice) { + return false, nil + } + return false, fmt.Errorf("error validating MFA device for token string %s and device ID %s: %w", userID, deviceID, err) + } + + // 1. Device Exists: ❌ | Device Valid: – | + // 2. Device Exists: βœ… | Device Valid: – | + // 3. Device Exists: βœ… | Device Valid: ❌ + if mc == nil || + mc.DeviceExpiresAt == nil || + (mc.DeviceExpiresAt != nil && mc.DeviceExpiresAt.Before(time.Now())) { + return false, nil + } + + // 4. Device Exists: βœ… | Device Valid: βœ… + return true, nil +} + +// GenerateMFACode generates a new MFA code for the user and device. +func (m *defaultMFAManager) GenerateMFACode(ctx context.Context, deviceID, userID string) (string, error) { + mc, err := m.getByDeviceAndUser(ctx, deviceID, userID) + if err != nil && !errors.Is(err, ErrMFANoCodeForUserDevice) { + return "", fmt.Errorf("error validating MFA device for user ID %s and device ID %s: %w", userID, deviceID, err) + } + + // 1. Device Exists: ❌ | Code Exists: - | Code Valid: - + // 2. Device Exists: βœ… | Code Exists: ❌ | Code Valid: - + // 3. Device Exists: βœ… | Code Exists: βœ… | Code Valid: ❌ + // β€· Persist & send new code + if mc == nil || mc.Code == "" || (mc.CodeExpiresAt != nil && mc.CodeExpiresAt.Before(time.Now())) { + return m.generateAndUpdateMFACode(ctx, deviceID, userID) + } + + // 4. Device Exists: βœ… | Code Exists: βœ… | Code Valid: βœ… + // β€· Explicitly expire the old code and generate a new one. + if mc.CodeExpiresAt != nil && mc.CodeExpiresAt.After(time.Now()) { + log.Ctx(ctx).Infof("expiring a valid MFA code for device ID %s and user ID %s", deviceID, userID) + err = m.expireMFACode(ctx, deviceID, mc.Code) + if err != nil { + return "", fmt.Errorf("expiring MFA code for device ID %s and code %s: %w", deviceID, mc.Code, err) + } + return m.generateAndUpdateMFACode(ctx, deviceID, userID) + } + + return "", nil +} + +// ValidateMFACode checks if the MFA code is valid for the device ID and returns the user ID. +func (m *defaultMFAManager) ValidateMFACode(ctx context.Context, deviceID, code string) (string, error) { + return db.RunInTransactionWithResult(ctx, m.dbConnectionPool, nil, func(dbTx db.DBTransaction) (string, error) { + mc, err := m.getByDeviceAndCode(ctx, deviceID, code) + if err != nil { + if errors.Is(err, ErrMFANoCodeForUserDevice) { + return "", ErrMFACodeInvalid + } + return "", fmt.Errorf("error validating MFA code for device ID %s: %w", deviceID, err) + } + + if mc != nil && mc.Code == code && mc.CodeExpiresAt != nil && mc.CodeExpiresAt.After(time.Now()) { + err = m.expireMFACode(ctx, deviceID, code) + if err != nil { + return "", fmt.Errorf("error expiring MFA code for device ID %s and code %s: %w", deviceID, code, err) + } + return mc.UserID, nil + } + + return "", ErrMFACodeInvalid + }) +} + +// RememberDevice updates the device expiry for the device. +func (m *defaultMFAManager) RememberDevice(ctx context.Context, deviceID, code string) error { + err := m.resetDeviceExpiry(ctx, deviceID, code) + if err != nil { + return fmt.Errorf("error updating device expiry for device ID %s and code %s: %w", deviceID, code, err) + } + return nil +} + +// ForgetDevice expires the device for the user. +func (m *defaultMFAManager) ForgetDevice(ctx context.Context, deviceID, userID string) error { + if deviceID == "" || userID == "" { + return fmt.Errorf("device ID and user ID are required") + } + + const query = ` + UPDATE auth_user_mfa_codes + SET device_expires_at = null + WHERE device_id = $1 AND auth_user_id = $2 + ` + _, err := m.dbConnectionPool.ExecContext(ctx, query, deviceID, userID) + if err != nil { + return fmt.Errorf("error expiring device for device ID %s and user ID %s: %w", deviceID, userID, err) + } + return nil +} + +// getByDeviceAndUser gets the MFA code for the user and device. +func (m *defaultMFAManager) getByDeviceAndUser(ctx context.Context, deviceID, userID string) (*mfaCode, error) { + if deviceID == "" || userID == "" { + return nil, fmt.Errorf("device ID and user ID are required") + } + const query = ` + SELECT + device_id, + auth_user_id, + COALESCE(code, '') AS code, + device_expires_at, + code_expires_at + FROM + auth_user_mfa_codes + WHERE + device_id = $1 AND + auth_user_id = $2 + ` + var mc mfaCode + err := m.dbConnectionPool.GetContext(ctx, &mc, query, deviceID, userID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrMFANoCodeForUserDevice + } + return nil, fmt.Errorf("error fetching MFA code for device ID %s and user ID %s: %w", deviceID, userID, err) + } + + return &mc, nil +} + +// getByDeviceAndCode gets the MFA code for the device and code. +func (m *defaultMFAManager) getByDeviceAndCode(ctx context.Context, deviceID, code string) (*mfaCode, error) { + if deviceID == "" || code == "" { + return nil, fmt.Errorf("device ID and code are required") + } + const query = ` + SELECT + device_id, + auth_user_id, + COALESCE(code, '') AS code, + device_expires_at, + code_expires_at + FROM + auth_user_mfa_codes + WHERE + device_id = $1 AND + code = $2 + ` + var mc mfaCode + err := m.dbConnectionPool.GetContext(ctx, &mc, query, deviceID, code) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrMFANoCodeForUserDevice + } + return nil, fmt.Errorf("error fetching MFA code for device ID %s: %w", deviceID, err) + } + + return &mc, nil +} + +// generateAndUpdateMFACode generates a new MFA code and upserts it for the user and device. +func (m *defaultMFAManager) generateAndUpdateMFACode(ctx context.Context, deviceID, userID string) (string, error) { + code, err := generateMFACode() + if err != nil { + return "", fmt.Errorf("error generating MFA code for user ID %s and device ID %s: %w", userID, deviceID, err) + } + err = m.upsertMFACode(ctx, deviceID, userID, code) + if err != nil { + return "", fmt.Errorf("error updating MFA code for user ID %s and device ID %s: %w", userID, deviceID, err) + } + return code, nil +} + +// upsertMFACode upserts the MFA code for the user and device. +func (m *defaultMFAManager) upsertMFACode(ctx context.Context, deviceID, userID, code string) error { + if deviceID == "" || userID == "" || code == "" { + return fmt.Errorf("device ID, user ID and code are required") + } + const query = ` + INSERT INTO auth_user_mfa_codes (auth_user_id, device_id, code, code_expires_at) + VALUES ($1, $2, $3, $4) + ON CONFLICT (auth_user_id, device_id) + DO UPDATE SET code = $3, code_expires_at = $4 + ` + _, err := m.dbConnectionPool.ExecContext(ctx, query, userID, deviceID, code, time.Now().Add(mfaCodeExpiryMinutes)) + if err != nil { + return fmt.Errorf("error upserting MFA code for user ID %s and device ID %s: %w", userID, deviceID, err) + } + return nil +} + +// resetDeviceExpiry resets the device expiry for the user and device. +func (m *defaultMFAManager) resetDeviceExpiry(ctx context.Context, deviceID, code string) error { + if deviceID == "" || code == "" { + return fmt.Errorf("device ID and code are required") + } + const query = ` + UPDATE auth_user_mfa_codes + SET device_expires_at = $1 + WHERE device_id = $2 AND code = $3 + ` + _, err := m.dbConnectionPool.ExecContext(ctx, query, time.Now().Add(mfaDeviceExpiryHours), deviceID, code) + if err != nil { + return fmt.Errorf("error updating device expiry for device ID %s and code %s: %w", deviceID, code, err) + } + return nil +} + +// expireMFACode expires the MFA code for the user and device. +func (m *defaultMFAManager) expireMFACode(ctx context.Context, deviceID, code string) error { + if deviceID == "" || code == "" { + return fmt.Errorf("device ID and code are required") + } + const query = ` + UPDATE auth_user_mfa_codes + SET code = null, code_expires_at = null + WHERE device_id = $1 AND code = $2 + ` + _, err := m.dbConnectionPool.ExecContext(ctx, query, deviceID, code) + if err != nil { + return fmt.Errorf("error expiring MFA code for device ID %s and code %s: %w", deviceID, code, err) + } + return nil +} + +// generateMFACode generate a random 6-digit MFA code. +func generateMFACode() (string, error) { + code := "" + for i := 0; i < mfaCodeMaxLength; i++ { + randomDigit, err := rand.Int(rand.Reader, big.NewInt(10)) + if err != nil { + return "", fmt.Errorf("error generating random digit for MFA code: %w", err) + } + code += fmt.Sprintf("%d", randomDigit) + } + return code, nil +} + +type defaultMFAManagerOption func(m *defaultMFAManager) + +func newDefaultMFAManager(options ...defaultMFAManagerOption) *defaultMFAManager { + mfaManager := &defaultMFAManager{} + + for _, option := range options { + option(mfaManager) + } + + return mfaManager +} + +func withMFADatabaseConnectionPool(dbConnectionPool db.DBConnectionPool) defaultMFAManagerOption { + return func(a *defaultMFAManager) { + a.dbConnectionPool = dbConnectionPool + } +} + +// Ensuring that defaultMFAManager is implementing MFAManager interface +var _ MFAManager = (*defaultMFAManager)(nil) diff --git a/stellar-auth/pkg/auth/mfa_manager_test.go b/stellar-auth/pkg/auth/mfa_manager_test.go new file mode 100644 index 000000000..a41dfbf47 --- /dev/null +++ b/stellar-auth/pkg/auth/mfa_manager_test.go @@ -0,0 +1,567 @@ +package auth + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_defaultMFAManager_MFADeviceRemembered(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or userID is empty", func(t *testing.T) { + _, err := m.MFADeviceRemembered(ctx, "", "") + require.ErrorContains(t, err, "device ID and user ID are required") + _, err = m.MFADeviceRemembered(ctx, "deviceID", "") + require.ErrorContains(t, err, "device ID and user ID are required") + _, err = m.MFADeviceRemembered(ctx, "", "userID") + require.ErrorContains(t, err, "device ID and user ID are required") + }) + + t.Run("Test error when user not found", func(t *testing.T) { + isRemembered, err := m.MFADeviceRemembered(ctx, "deviceID", "nonExistentUser") + require.NoError(t, err) + require.False(t, isRemembered) + }) + + t.Run("Test Device Exists: ❌ | Device Valid: – |", func(t *testing.T) { + isRemembered, err := m.MFADeviceRemembered(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.False(t, isRemembered) + }) + + t.Run("Test Device Exists: βœ… | Device Valid: ❌ |", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + // Generate code for device and expire device + _, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + err = m.ForgetDevice(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + + isValid, err := m.MFADeviceRemembered(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.False(t, isValid) + }) + + t.Run("Test Device Exists: βœ… | Device Valid: βœ… |", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + // Generate code for device and remember device + code, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + err = m.RememberDevice(ctx, "deviceID", code) + require.NoError(t, err) + + // Validate device + isRemembered, err := m.MFADeviceRemembered(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.True(t, isRemembered) + }) +} + +func Test_defaultMFAManager_GenerateMFACode(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or userID is empty", func(t *testing.T) { + _, err := m.GenerateMFACode(ctx, "", "") + require.ErrorContains(t, err, "device ID and user ID are required") + _, err = m.GenerateMFACode(ctx, "deviceID", "") + require.ErrorContains(t, err, "device ID and user ID are required") + _, err = m.GenerateMFACode(ctx, "", "userID") + require.ErrorContains(t, err, "device ID and user ID are required") + }) + + t.Run("Test error when user not found", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + _, err := m.GenerateMFACode(ctx, "deviceID", "nonExistentUser") + require.ErrorContains(t, err, "error updating MFA code for user ID nonExistentUser and device ID deviceID") + }) + + t.Run("Test Device Exists: ❌ | Code Exists: - | Code Valid: -", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + code, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.NotNil(t, code) + require.Equal(t, 6, len(code)) + + mc, err := m.getByDeviceAndCode(ctx, "deviceID", code) + require.NoError(t, err) + require.NotNil(t, mc) + require.Equal(t, code, mc.Code) + require.Equal(t, "deviceID", mc.DeviceID) + require.Equal(t, randUser.ID, mc.UserID) + require.Nil(t, mc.DeviceExpiresAt) + require.True(t, mc.CodeExpiresAt.After(time.Now().Add(mfaCodeExpiryMinutes).Add(-time.Minute))) + }) + + t.Run("Test Device Exists: βœ… | Code Exists: ❌ | Code Valid: -", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + // Insert entry for `deviceID` and `randUser.ID` + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, auth_user_id, device_expires_at) + VALUES ($1, $2, NOW() + INTERVAL '1 hour')`, "deviceID", randUser.ID) + require.NoError(t, err) + + // Generate new code for `deviceID` and `randUser.ID` + code, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.NotNil(t, code) + require.Equal(t, 6, len(code)) + + mc, err := m.getByDeviceAndCode(ctx, "deviceID", code) + require.NoError(t, err) + require.NotNil(t, mc) + require.Equal(t, code, mc.Code) + require.Equal(t, "deviceID", mc.DeviceID) + require.Equal(t, randUser.ID, mc.UserID) + require.True(t, mc.CodeExpiresAt.After(time.Now().Add(mfaCodeExpiryMinutes).Add(-time.Minute))) + }) + + t.Run("Test Device Exists: βœ… | Code Exists: βœ… | Code Valid: ❌", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + // Generate code and expire it + expiredCode, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + _, err = dbConnectionPool.ExecContext(ctx, ` + UPDATE auth_user_mfa_codes SET code_expires_at = NOW() - INTERVAL '1 hour' + WHERE device_id = $1 AND auth_user_id = $2`, "deviceID", randUser.ID) + require.NoError(t, err) + + // Generate new code for `deviceID` and `randUser.ID` + code, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.NotNil(t, code) + require.Equal(t, 6, len(code)) + require.NotEqual(t, expiredCode, code) + + mc, err := m.getByDeviceAndCode(ctx, "deviceID", code) + require.NoError(t, err) + require.NotNil(t, mc) + require.Equal(t, code, mc.Code) + require.Equal(t, "deviceID", mc.DeviceID) + require.Equal(t, randUser.ID, mc.UserID) + require.Nil(t, mc.DeviceExpiresAt) + require.True(t, mc.CodeExpiresAt.After(time.Now().Add(mfaCodeExpiryMinutes).Add(-time.Minute))) + }) + + t.Run("Test code expired and re-generated when valid one exists", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + // Generate code + code, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.NotNil(t, code) + require.Equal(t, 6, len(code)) + + // Try generating another one + newCode, err := m.GenerateMFACode(ctx, "deviceID", randUser.ID) + require.NoError(t, err) + require.NotEqual(t, newCode, code) + }) +} + +func Test_defaultMFAManager_ValidateMFACode(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or code is empty", func(t *testing.T) { + _, err := m.ValidateMFACode(ctx, "", "") + require.ErrorContains(t, err, "device ID and code are required") + _, err = m.ValidateMFACode(ctx, "deviceID", "") + require.ErrorContains(t, err, "device ID and code are required") + _, err = m.ValidateMFACode(ctx, "", "code") + require.ErrorContains(t, err, "device ID and code are required") + }) + + t.Run("Test MFA code validation", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "111333" + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, code, auth_user_id, device_expires_at, code_expires_at) + VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour', NOW() + INTERVAL '1 hour')`, testDeviceID, testCode, randUser.ID) + require.NoError(t, err) + + // Test MFA code validation + userID, err := m.ValidateMFACode(ctx, testDeviceID, testCode) + assert.NoError(t, err) + assert.Equal(t, randUser.ID, userID) + }) + + t.Run("Test invalid MFA code", func(t *testing.T) { + testDeviceID := "anotherDeviceID" + testCode := "222333" + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, code, auth_user_id, device_expires_at, code_expires_at) + VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour', NOW() - INTERVAL '1 hour')`, testDeviceID, testCode, randUser.ID) + require.NoError(t, err) + + _, err = m.ValidateMFACode(ctx, testDeviceID, testCode) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrMFACodeInvalid)) + }) +} + +func Test_defaultMFAManager_RememberDevice(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or code is empty", func(t *testing.T) { + err := m.RememberDevice(ctx, "", "") + require.ErrorContains(t, err, "device ID and code are required") + err = m.RememberDevice(ctx, "deviceID", "") + require.ErrorContains(t, err, "device ID and code are required") + err = m.RememberDevice(ctx, "", "code") + require.ErrorContains(t, err, "device ID and code are required") + }) + + t.Run("Test updating device expiry", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "111333" + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, code, auth_user_id, device_expires_at, code_expires_at) + VALUES ($1, $2, $3, NOW() - INTERVAL '1 hour', NOW() + INTERVAL '1 hour')`, testDeviceID, testCode, randUser.ID) + require.NoError(t, err) + + err = m.RememberDevice(ctx, testDeviceID, testCode) + require.NoError(t, err) + + mc, err := m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.True(t, mc.DeviceExpiresAt.After(time.Now())) + }) +} + +func Test_defaultMFAManager_ForgetDevice(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or code is empty", func(t *testing.T) { + err := m.ForgetDevice(ctx, "", "") + require.EqualError(t, err, "device ID and user ID are required") + err = m.ForgetDevice(ctx, "deviceID", "") + require.EqualError(t, err, "device ID and user ID are required") + err = m.ForgetDevice(ctx, "", "code") + require.EqualError(t, err, "device ID and user ID are required") + }) + + t.Run("Test forget device", func(t *testing.T) { + defer cleanup(t, ctx, dbConnectionPool) + + testDeviceID := "testDeviceID" + + // Generate code and remember device + code, err := m.GenerateMFACode(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.Equal(t, 6, len(code)) + + err = m.RememberDevice(ctx, testDeviceID, code) + require.NoError(t, err) + + // Fetch entry and check that device is remembered + mc, err := m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.NotNil(t, mc) + require.True(t, mc.DeviceExpiresAt.After(time.Now())) + + // Forget device + err = m.ForgetDevice(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + + // Fetch entry and check that device is forgotten + mc, err = m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.NotNil(t, mc) + require.Nil(t, mc.DeviceExpiresAt) + }) +} + +func Test_defaultMFAManager_getByDeviceAndCode(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or code is empty", func(t *testing.T) { + _, err := m.getByDeviceAndCode(ctx, "", "") + require.EqualError(t, err, "device ID and code are required") + _, err = m.getByDeviceAndCode(ctx, "deviceID", "") + require.EqualError(t, err, "device ID and code are required") + _, err = m.getByDeviceAndCode(ctx, "", "code") + require.EqualError(t, err, "device ID and code are required") + }) + + t.Run("Test fetching MFA code by device and code", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "111333" + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, code, auth_user_id, code_expires_at) + VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour')`, testDeviceID, testCode, randUser.ID) + require.NoError(t, err) + + mc, err := m.getByDeviceAndCode(ctx, testDeviceID, testCode) + require.NoError(t, err) + require.NotNil(t, mc) + require.Equal(t, testCode, mc.Code) + require.Equal(t, testDeviceID, mc.DeviceID) + require.Equal(t, randUser.ID, mc.UserID) + require.Nil(t, mc.DeviceExpiresAt) + require.True(t, mc.CodeExpiresAt.After(time.Now().Add(mfaCodeExpiryMinutes).Add(-time.Minute))) + }) + + t.Run("Test fetching non-existent MFA code", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "nonExistentCode" + + // Test fetching MFA code + _, err := m.getByDeviceAndCode(ctx, testDeviceID, testCode) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrMFANoCodeForUserDevice)) + }) +} + +func Test_defaultMFAManager_generateAndUpdateMFACode(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test generate and upsert new MFA code", func(t *testing.T) { + testDeviceID := "testDeviceID" + + generatedCode, err := m.generateAndUpdateMFACode(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.NotEmpty(t, generatedCode) + + mc, err := m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.Equal(t, generatedCode, mc.Code) + require.Equal(t, testDeviceID, mc.DeviceID) + require.Equal(t, randUser.ID, mc.UserID) + require.Nil(t, mc.DeviceExpiresAt) + require.True(t, mc.CodeExpiresAt.After(time.Now().Add(mfaCodeExpiryMinutes).Add(-time.Minute))) + }) +} + +func Test_defaultMFAManager_upsertMFACode(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test upsert new MFA code", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "111333" + + // Test inserting new MFA code + err := m.upsertMFACode(ctx, testDeviceID, randUser.ID, testCode) + assert.NoError(t, err) + + // Check that the record was inserted correctly + mc, err := m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + assert.Equal(t, testCode, mc.Code) + + // Cleanup: Delete the test record + _, err = dbConnectionPool.ExecContext(ctx, ` + DELETE FROM auth_user_mfa_codes WHERE device_id = $1 AND auth_user_id = $2`, testDeviceID, randUser.ID) + require.NoError(t, err) + }) + + t.Run("Test update existing MFA code", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "111333" + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, code, auth_user_id, code_expires_at) + VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour')`, testDeviceID, testCode, randUser.ID) + require.NoError(t, err) + + // Test updating existing MFA code + newCode := "222444" + err = m.upsertMFACode(ctx, testDeviceID, randUser.ID, newCode) + assert.NoError(t, err) + + // Check that the record was updated correctly + mc, err := m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + assert.Equal(t, newCode, mc.Code) + + // Cleanup: Delete the test record + _, err = dbConnectionPool.ExecContext(ctx, ` + DELETE FROM auth_user_mfa_codes WHERE device_id = $1 AND auth_user_id = $2`, testDeviceID, randUser.ID) + require.NoError(t, err) + }) +} + +func Test_defaultMFAManager_resetDeviceExpiry(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or code is empty", func(t *testing.T) { + err := m.resetDeviceExpiry(ctx, "", "") + assert.EqualError(t, err, "device ID and code are required") + err = m.resetDeviceExpiry(ctx, "deviceID", "") + assert.EqualError(t, err, "device ID and code are required") + err = m.resetDeviceExpiry(ctx, "", "code") + assert.EqualError(t, err, "device ID and code are required") + }) + + t.Run("Test device expiry reset", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "111333" + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, code, auth_user_id, code_expires_at) + VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour')`, testDeviceID, testCode, randUser.ID) + require.NoError(t, err) + + err = m.resetDeviceExpiry(ctx, testDeviceID, testCode) + assert.NoError(t, err) + + // Check that the record was updated correctly + mc, err := m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.True(t, mc.DeviceExpiresAt.After(time.Now().Add(mfaDeviceExpiryHours).Add(-time.Minute))) + }) +} + +func Test_defaultMFAManager_expireMFACode(t *testing.T) { + ctx := context.Background() + + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + randUser := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, NewDefaultPasswordEncrypter(), false) + + m := newDefaultMFAManager(withMFADatabaseConnectionPool(dbConnectionPool)) + + t.Run("Test error when deviceID or code is empty", func(t *testing.T) { + err := m.expireMFACode(ctx, "", "") + assert.EqualError(t, err, "device ID and code are required") + err = m.expireMFACode(ctx, "deviceID", "") + assert.EqualError(t, err, "device ID and code are required") + err = m.expireMFACode(ctx, "", "code") + assert.EqualError(t, err, "device ID and code are required") + }) + + t.Run("Test entry not found", func(t *testing.T) { + testDeviceID := "testDeviceID" + testCode := "111333" + _, err := dbConnectionPool.ExecContext(ctx, ` + INSERT INTO auth_user_mfa_codes (device_id, code, auth_user_id, code_expires_at) + VALUES ($1, $2, $3, NOW() + INTERVAL '1 hour')`, testDeviceID, testCode, randUser.ID) + require.NoError(t, err) + + err = m.expireMFACode(ctx, testDeviceID, testCode) + assert.NoError(t, err) + + // Check that the record was updated correctly + mc, err := m.getByDeviceAndUser(ctx, testDeviceID, randUser.ID) + require.NoError(t, err) + require.Nil(t, mc.CodeExpiresAt) + require.Equal(t, "", mc.Code) + }) +} + +func Test_defaultMFAManager_generateMFACode(t *testing.T) { + code, err := generateMFACode() + assert.NoError(t, err) + assert.Equal(t, 6, len(code)) + for _, c := range code { + assert.True(t, c >= '0' && c <= '9') + } +} + +func cleanup(t *testing.T, ctx context.Context, dbConnectionPool db.DBConnectionPool) { + _, err := dbConnectionPool.ExecContext(ctx, "DELETE FROM auth_user_mfa_codes") + require.NoError(t, err) +} diff --git a/stellar-auth/pkg/auth/mocks.go b/stellar-auth/pkg/auth/mocks.go new file mode 100644 index 000000000..95b1f859a --- /dev/null +++ b/stellar-auth/pkg/auth/mocks.go @@ -0,0 +1,291 @@ +package auth + +import ( + "context" + "time" + + "github.com/stretchr/testify/mock" +) + +// PasswordEncrypter +type PasswordEncrypterMock struct { + mock.Mock +} + +func (em *PasswordEncrypterMock) Encrypt(ctx context.Context, password string) (string, error) { + args := em.Called(ctx, password) + return args.Get(0).(string), args.Error(1) +} + +func (em *PasswordEncrypterMock) ComparePassword(ctx context.Context, encryptedPassword, password string) (bool, error) { + args := em.Called(ctx, encryptedPassword, password) + return args.Get(0).(bool), args.Error(1) +} + +var _ PasswordEncrypter = (*PasswordEncrypterMock)(nil) + +// JWTManager +type JWTManagerMock struct { + mock.Mock +} + +func (m *JWTManagerMock) GenerateToken(ctx context.Context, user *User, expiresAt time.Time) (string, error) { + args := m.Called(ctx, user, expiresAt) + return args.Get(0).(string), args.Error(1) +} + +func (m *JWTManagerMock) RefreshToken(ctx context.Context, token string, expiresAt time.Time) (string, error) { + args := m.Called(ctx, token, expiresAt) + return args.Get(0).(string), args.Error(1) +} + +func (m *JWTManagerMock) ValidateToken(ctx context.Context, token string) (bool, error) { + args := m.Called(ctx, token) + return args.Get(0).(bool), args.Error(1) +} + +func (m *JWTManagerMock) GetUserFromToken(ctx context.Context, tokenString string) (*User, error) { + args := m.Called(ctx, tokenString) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*User), args.Error(1) +} + +var _ JWTManager = (*JWTManagerMock)(nil) + +// Authenticator +type AuthenticatorMock struct { + mock.Mock +} + +func (am *AuthenticatorMock) ValidateCredentials(ctx context.Context, email, password string) (*User, error) { + args := am.Called(ctx, email, password) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*User), args.Error(1) +} + +func (am *AuthenticatorMock) CreateUser(ctx context.Context, user *User, password string) (*User, error) { + args := am.Called(ctx, user, password) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*User), args.Error(1) +} + +func (am *AuthenticatorMock) UpdateUser(ctx context.Context, ID, firstName, lastName, email, password string) error { + args := am.Called(ctx, ID, firstName, lastName, email, password) + return args.Error(0) +} + +func (am *AuthenticatorMock) ActivateUser(ctx context.Context, userID string) error { + args := am.Called(ctx, userID) + return args.Error(0) +} + +func (am *AuthenticatorMock) DeactivateUser(ctx context.Context, userID string) error { + args := am.Called(ctx, userID) + return args.Error(0) +} + +func (am *AuthenticatorMock) ForgotPassword(ctx context.Context, email string) (string, error) { + args := am.Called(ctx, email) + return args.Get(0).(string), args.Error(1) +} + +func (am *AuthenticatorMock) ResetPassword(ctx context.Context, resetToken, password string) error { + args := am.Called(ctx, resetToken, password) + return args.Error(0) +} + +func (am *AuthenticatorMock) GetAllUsers(ctx context.Context) ([]User, error) { + args := am.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]User), args.Error(1) +} + +func (am *AuthenticatorMock) GetUser(ctx context.Context, userID string) (*User, error) { + args := am.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*User), args.Error(1) +} + +var _ Authenticator = (*AuthenticatorMock)(nil) + +type RoleManagerMock struct { + mock.Mock +} + +func (rm *RoleManagerMock) GetUserRoles(ctx context.Context, user *User) ([]string, error) { + args := rm.Called(ctx, user) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), args.Error(1) +} + +func (rm *RoleManagerMock) HasAllRoles(ctx context.Context, user *User, roleNames []string) (bool, error) { + args := rm.Called(ctx, user, roleNames) + return args.Get(0).(bool), args.Error(1) +} + +func (rm *RoleManagerMock) HasAnyRoles(ctx context.Context, user *User, roleNames []string) (bool, error) { + args := rm.Called(ctx, user, roleNames) + return args.Get(0).(bool), args.Error(1) +} + +func (rm *RoleManagerMock) IsSuperUser(ctx context.Context, user *User) (bool, error) { + args := rm.Called(ctx, user) + return args.Get(0).(bool), args.Error(1) +} + +func (rm *RoleManagerMock) UpdateRoles(ctx context.Context, user *User, roleNames []string) error { + args := rm.Called(ctx, user, roleNames) + return args.Error(0) +} + +var _ RoleManager = (*RoleManagerMock)(nil) + +// MFAManager +type MFAManagerMock struct { + mock.Mock +} + +func (m *MFAManagerMock) MFADeviceRemembered(ctx context.Context, deviceID, userID string) (bool, error) { + args := m.Called(ctx, deviceID, userID) + return args.Get(0).(bool), args.Error(1) +} + +func (m *MFAManagerMock) GenerateMFACode(ctx context.Context, deviceID, userID string) (string, error) { + args := m.Called(ctx, deviceID, userID) + return args.Get(0).(string), args.Error(1) +} + +func (m *MFAManagerMock) ValidateMFACode(ctx context.Context, deviceID, code string) (string, error) { + args := m.Called(ctx, deviceID, code) + return args.Get(0).(string), args.Error(1) +} + +func (m *MFAManagerMock) RememberDevice(ctx context.Context, deviceID, code string) error { + args := m.Called(ctx, deviceID, code) + return args.Error(0) +} + +var _ MFAManager = (*MFAManagerMock)(nil) + +// AuthManager +type AuthManagerMock struct { + mock.Mock +} + +func (am *AuthManagerMock) Authenticate(ctx context.Context, email, pass string) (string, error) { + args := am.Called(ctx, email, pass) + return args.Get(0).(string), args.Error(1) +} + +func (am *AuthManagerMock) RefreshToken(ctx context.Context, tokenString string) (string, error) { + args := am.Called(ctx, tokenString) + return args.Get(0).(string), args.Error(1) +} + +func (am *AuthManagerMock) ValidateToken(ctx context.Context, tokenString string) (bool, error) { + args := am.Called(ctx, tokenString) + return args.Get(0).(bool), args.Error(1) +} + +func (am *AuthManagerMock) AllRolesInTokenUser(ctx context.Context, tokenString string, roleNames []string) (bool, error) { + args := am.Called(ctx, tokenString, roleNames) + return args.Get(0).(bool), args.Error(1) +} + +func (am *AuthManagerMock) AnyRolesInTokenUser(ctx context.Context, tokenString string, roleNames []string) (bool, error) { + args := am.Called(ctx, tokenString, roleNames) + return args.Get(0).(bool), args.Error(1) +} + +func (am *AuthManagerMock) CreateUser(ctx context.Context, user *User, password string) (*User, error) { + args := am.Called(ctx, user, password) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*User), args.Error(1) +} + +func (am *AuthManagerMock) UpdateUser(ctx context.Context, tokenString, firstName, lastName, email, password string) error { + args := am.Called(ctx, tokenString, firstName, lastName, email, password) + return args.Error(0) +} + +func (am *AuthManagerMock) ForgotPassword(ctx context.Context, email string) (string, error) { + args := am.Called(ctx, email) + return args.Get(0).(string), args.Error(1) +} + +func (am *AuthManagerMock) ResetPassword(ctx context.Context, tokenString, password string) error { + args := am.Called(ctx, tokenString, password) + return args.Error(0) +} + +func (am *AuthManagerMock) GetUser(ctx context.Context, tokenString string) (*User, error) { + args := am.Called(ctx, tokenString) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*User), args.Error(1) +} + +func (am *AuthManagerMock) GetAllUsers(ctx context.Context, tokenString string) ([]User, error) { + args := am.Called(ctx, tokenString) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]User), args.Error(1) +} + +func (am *AuthManagerMock) UpdateUserRoles(ctx context.Context, tokenString, userID string, roles []string) error { + args := am.Called(ctx, tokenString, userID, roles) + return args.Error(0) +} + +func (am *AuthManagerMock) DeactivateUser(ctx context.Context, tokenString, userID string) error { + args := am.Called(ctx, tokenString, userID) + return args.Error(0) +} + +func (am *AuthManagerMock) ActivateUser(ctx context.Context, tokenString, userID string) error { + args := am.Called(ctx, tokenString, userID) + return args.Error(0) +} + +func (am *AuthManagerMock) ExpirationTimeInMinutes() time.Duration { + args := am.Called() + return args.Get(0).(time.Duration) +} + +func (am *AuthManagerMock) MFADeviceRemembered(ctx context.Context, userID, deviceID string) (bool, error) { + args := am.Called(ctx, userID, deviceID) + return args.Get(0).(bool), args.Error(1) +} + +func (am *AuthManagerMock) GetMFACode(ctx context.Context, userID, deviceID string) (string, error) { + args := am.Called(ctx, userID, deviceID) + return args.Get(0).(string), args.Error(1) +} + +func (am *AuthManagerMock) GenerateMFACode(ctx context.Context, userID, deviceID string) (string, error) { + args := am.Called(ctx, userID, deviceID) + return args.Get(0).(string), args.Error(1) +} + +func (am *AuthManagerMock) AuthenticateMFA(ctx context.Context, deviceID, code string, rememberMe bool) (string, error) { + args := am.Called(ctx, deviceID, code, rememberMe) + return args.Get(0).(string), args.Error(1) +} + +var _ AuthManager = (*AuthManagerMock)(nil) diff --git a/stellar-auth/pkg/auth/password_encrypter.go b/stellar-auth/pkg/auth/password_encrypter.go new file mode 100644 index 000000000..de9359dda --- /dev/null +++ b/stellar-auth/pkg/auth/password_encrypter.go @@ -0,0 +1,58 @@ +package auth + +import ( + "context" + "errors" + "fmt" + + "golang.org/x/crypto/bcrypt" +) + +const ( + minPasswordLength = 8 + maxPasswordLength = 16 +) + +var ErrPasswordTooShort = errors.New("password should have at least 8 characters") + +// PasswordEncrypter is a interface that defines the methods to encrypt passwords and compare a password with its stored hash. +// This interface is used by `DefaultAuthenticator` as the type of `passwordEncrypter` attribute. +type PasswordEncrypter interface { + // Encrypt encrypts the `password` and return a hash. + Encrypt(ctx context.Context, password string) (string, error) + + // ComparePassword compares the `encryptedPassword` with the plain `password` to verify if it's correct. + ComparePassword(ctx context.Context, encryptedPassword, password string) (bool, error) +} + +// DefaultPasswordEncrypter defines the default way of encrypting passwords and comparing passwords with its stored hash. +// It uses `bcrypt` library to handle with the encryption and comparison. +type DefaultPasswordEncrypter struct{} + +func (e *DefaultPasswordEncrypter) Encrypt(ctx context.Context, password string) (string, error) { + // Assumes that a password can't have less than 8 characters. + if len(password) < minPasswordLength { + return "", ErrPasswordTooShort + } + + encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("encrypting password: %w", err) + } + + return string(encryptedPassword), nil +} + +func (e *DefaultPasswordEncrypter) ComparePassword(ctx context.Context, encryptedPassword, password string) (bool, error) { + err := bcrypt.CompareHashAndPassword([]byte(encryptedPassword), []byte(password)) + if err != nil && !errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return false, fmt.Errorf("comparing encrypted password and password: %w", err) + } + return err == nil, nil +} + +func NewDefaultPasswordEncrypter() *DefaultPasswordEncrypter { + return &DefaultPasswordEncrypter{} +} + +var _ PasswordEncrypter = (*DefaultPasswordEncrypter)(nil) diff --git a/stellar-auth/pkg/auth/password_encrypter_test.go b/stellar-auth/pkg/auth/password_encrypter_test.go new file mode 100644 index 000000000..1fd215a6f --- /dev/null +++ b/stellar-auth/pkg/auth/password_encrypter_test.go @@ -0,0 +1,81 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_DefaultPasswordEncrypter_Encrypt(t *testing.T) { + passwordEncrypter := NewDefaultPasswordEncrypter() + + ctx := context.Background() + + t.Run("returns err when password is too short", func(t *testing.T) { + password := "" + + encryptedPassword, err := passwordEncrypter.Encrypt(ctx, password) + + assert.EqualError(t, err, ErrPasswordTooShort.Error()) + assert.Empty(t, encryptedPassword) + + password = "secret" + + encryptedPassword, err = passwordEncrypter.Encrypt(ctx, password) + + assert.EqualError(t, err, ErrPasswordTooShort.Error()) + assert.Empty(t, encryptedPassword) + }) + + t.Run("encrypts the password correctly", func(t *testing.T) { + password := "mysecret" + + encryptedPassword, err := passwordEncrypter.Encrypt(ctx, password) + require.NoError(t, err) + + assert.NotEmpty(t, encryptedPassword) + assert.NotEqual(t, password, encryptedPassword) + assert.Len(t, encryptedPassword, 60) + + password = "myanothersecret" + + encryptedPassword, err = passwordEncrypter.Encrypt(ctx, password) + require.NoError(t, err) + + assert.NotEmpty(t, encryptedPassword) + assert.NotEqual(t, password, encryptedPassword) + assert.Len(t, encryptedPassword, 60) + }) +} + +func Test_DefaultPasswordEncrypter_ComparePassword(t *testing.T) { + passwordEncrypter := NewDefaultPasswordEncrypter() + + ctx := context.Background() + + t.Run("returns false when the password is wrong", func(t *testing.T) { + password := "mysecret" + + encryptedPassword, err := passwordEncrypter.Encrypt(ctx, password) + require.NoError(t, err) + + isEqual, err := passwordEncrypter.ComparePassword(ctx, encryptedPassword, "wrongsecret") + require.NoError(t, err) + + assert.False(t, isEqual) + }) + + t.Run("returns true when the password is correct", func(t *testing.T) { + password := "mysecret" + + encryptedPassword, err := passwordEncrypter.Encrypt(ctx, password) + require.NoError(t, err) + + isEqual, err := passwordEncrypter.ComparePassword(ctx, encryptedPassword, password) + require.NoError(t, err) + + assert.True(t, isEqual) + }) +} diff --git a/stellar-auth/pkg/auth/role_manager.go b/stellar-auth/pkg/auth/role_manager.go new file mode 100644 index 000000000..7e87bcfe4 --- /dev/null +++ b/stellar-auth/pkg/auth/role_manager.go @@ -0,0 +1,154 @@ +package auth + +import ( + "context" + "fmt" + + "github.com/lib/pq" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" +) + +const defaultOwnerRoleName = "owner" + +type RoleManager interface { + GetUserRoles(ctx context.Context, user *User) ([]string, error) + // HasAllRoles validates whether the user has all roles passed by parameter. + HasAllRoles(ctx context.Context, user *User, roleNames []string) (bool, error) + // HasAnyRoles validates whether the user has one or more roles passed by parameter. + HasAnyRoles(ctx context.Context, user *User, roleNames []string) (bool, error) + IsSuperUser(ctx context.Context, user *User) (bool, error) + UpdateRoles(ctx context.Context, user *User, roleNames []string) error +} + +type userRolesInfo struct { + Roles pq.StringArray `db:"roles"` + IsOwner bool `db:"is_owner"` +} + +type defaultRoleManager struct { + dbConnectionPool db.DBConnectionPool + ownerRoleName string +} + +func (rm *defaultRoleManager) getUserRolesInfo(ctx context.Context, user *User) (*userRolesInfo, error) { + const query = ` + SELECT roles, is_owner FROM auth_users WHERE id = $1 + ` + + var ur userRolesInfo + err := rm.dbConnectionPool.GetContext(ctx, &ur, query, user.ID) + if err != nil { + return nil, fmt.Errorf("error querying user ID %s roles: %w", user.ID, err) + } + + return &ur, nil +} + +func (rm *defaultRoleManager) GetUserRoles(ctx context.Context, user *User) ([]string, error) { + ur, err := rm.getUserRolesInfo(ctx, user) + if err != nil { + return nil, err + } + + if ur.IsOwner { + return []string{rm.ownerRoleName}, nil + } + + return ur.Roles, nil +} + +func (rm *defaultRoleManager) HasAllRoles(ctx context.Context, user *User, roleNames []string) (bool, error) { + userRoles, err := rm.GetUserRoles(ctx, user) + if err != nil { + return false, err + } + + userRolesMap := make(map[string]struct{}, len(userRoles)) + for _, role := range userRoles { + userRolesMap[role] = struct{}{} + } + + for _, role := range roleNames { + if _, ok := userRolesMap[role]; !ok { + return false, nil + } + } + + return true, nil +} + +func (rm *defaultRoleManager) HasAnyRoles(ctx context.Context, user *User, roleNames []string) (bool, error) { + userRoles, err := rm.GetUserRoles(ctx, user) + if err != nil { + return false, err + } + + userRolesMap := make(map[string]struct{}, len(userRoles)) + for _, role := range userRoles { + userRolesMap[role] = struct{}{} + } + + for _, role := range roleNames { + if _, ok := userRolesMap[role]; ok { + return true, nil + } + } + + return false, nil +} + +func (rm *defaultRoleManager) IsSuperUser(ctx context.Context, user *User) (bool, error) { + ur, err := rm.getUserRolesInfo(ctx, user) + if err != nil { + return false, err + } + + return ur.IsOwner, nil +} + +func (rm *defaultRoleManager) UpdateRoles(ctx context.Context, user *User, roleNames []string) error { + const query = "UPDATE auth_users SET roles = $1 WHERE id = $2" + result, err := rm.dbConnectionPool.ExecContext(ctx, query, pq.Array(roleNames), user.ID) + if err != nil { + return fmt.Errorf("error updating user roles ID %s roles: %w", user.ID, err) + } + + numRowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("error getting number of rows affected: %w", err) + } + + if numRowsAffected == 0 { + return ErrNoRowsAffected + } + + return nil +} + +var _ RoleManager = (*defaultRoleManager)(nil) + +type defaultRoleManagerOption func(m *defaultRoleManager) + +func newDefaultRoleManager(options ...defaultRoleManagerOption) *defaultRoleManager { + defaultRoleManager := &defaultRoleManager{ + ownerRoleName: defaultOwnerRoleName, + } + + for _, option := range options { + option(defaultRoleManager) + } + + return defaultRoleManager +} + +func withRoleManagerDBConnectionPool(dbConnectionPool db.DBConnectionPool) defaultRoleManagerOption { + return func(m *defaultRoleManager) { + m.dbConnectionPool = dbConnectionPool + } +} + +func withOwnerRoleName(ownerRoleName string) defaultRoleManagerOption { + return func(m *defaultRoleManager) { + m.ownerRoleName = ownerRoleName + } +} diff --git a/stellar-auth/pkg/auth/role_manager_test.go b/stellar-auth/pkg/auth/role_manager_test.go new file mode 100644 index 000000000..8dbbc9177 --- /dev/null +++ b/stellar-auth/pkg/auth/role_manager_test.go @@ -0,0 +1,333 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_DefaultRoleManager_getUserRolesInfo(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + pe := NewDefaultPasswordEncrypter() + rm := newDefaultRoleManager(withRoleManagerDBConnectionPool(dbConnectionPool)) + + t.Run("returns correctly when user is a super user", func(t *testing.T) { + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, true) + + u := &User{ + ID: rau.ID, + Email: rau.Email, + Roles: []string{"role1"}, + } + + ur, err := rm.getUserRolesInfo(ctx, u) + require.NoError(t, err) + + assert.True(t, ur.IsOwner) + }) + + t.Run("returns correctly when user isn't a super user", func(t *testing.T) { + roles := []string{"role1"} + + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, false, roles...) + + u := &User{ + ID: rau.ID, + Email: rau.Email, + Roles: []string{"role1"}, + } + + ur, err := rm.getUserRolesInfo(ctx, u) + require.NoError(t, err) + + assert.False(t, ur.IsOwner) + assert.Equal(t, roles, []string(ur.Roles)) + }) + + t.Run("returns correctly when user has no roles and is not super user", func(t *testing.T) { + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, false) + + u := &User{ + ID: rau.ID, + Email: rau.Email, + Roles: []string{"role1"}, + } + + ur, err := rm.getUserRolesInfo(ctx, u) + require.NoError(t, err) + + assert.False(t, ur.IsOwner) + assert.Empty(t, ur.Roles) + }) +} + +func Test_DefaultRoleManager_GetUserRoles(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + pe := NewDefaultPasswordEncrypter() + rm := newDefaultRoleManager( + withRoleManagerDBConnectionPool(dbConnectionPool), + ) + + t.Run("returns all the roles correctly", func(t *testing.T) { + expectedRoles := []string{"role1", "role2", "role3"} + + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, false, expectedRoles...) + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + gotRoles, err := rm.GetUserRoles(ctx, u) + require.NoError(t, err) + + assert.Equal(t, expectedRoles, gotRoles) + }) + + t.Run("returns owner role correctly", func(t *testing.T) { + roles := []string{"role1", "role2", "role3"} + + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, true, roles...) + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + gotRoles, err := rm.GetUserRoles(ctx, u) + require.NoError(t, err) + + assert.Equal(t, []string{defaultOwnerRoleName}, gotRoles) + }) +} + +func Test_DefaultRoleManager_HasAllRoles(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + pe := NewDefaultPasswordEncrypter() + rm := newDefaultRoleManager( + withRoleManagerDBConnectionPool(dbConnectionPool), + ) + + t.Run("return false when user isOwner but doesn't have the roles", func(t *testing.T) { + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, true, "role1") + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + hasRoles, err := rm.HasAllRoles(ctx, u, []string{"role1", "role2", "role3"}) + require.NoError(t, err) + + assert.False(t, hasRoles) + }) + + t.Run("validates the user roles correctly", func(t *testing.T) { + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, false, "role1", "role2") + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + hasRoles, err := rm.HasAllRoles(ctx, u, []string{"role1", "role2", "role3"}) + require.NoError(t, err) + assert.False(t, hasRoles) + + hasRoles, err = rm.HasAllRoles(ctx, u, []string{"role3"}) + require.NoError(t, err) + assert.False(t, hasRoles) + + hasRoles, err = rm.HasAllRoles(ctx, u, []string{"role1"}) + require.NoError(t, err) + assert.True(t, hasRoles) + + hasRoles, err = rm.HasAllRoles(ctx, u, []string{"role2"}) + require.NoError(t, err) + assert.True(t, hasRoles) + + hasRoles, err = rm.HasAllRoles(ctx, u, []string{"role1", "role2"}) + require.NoError(t, err) + assert.True(t, hasRoles) + + hasRoles, err = rm.HasAllRoles(ctx, u, []string{"role1", "role3"}) + require.NoError(t, err) + assert.False(t, hasRoles) + }) +} + +func Test_DefaultRoleManager_HasAnyRoles(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + pe := NewDefaultPasswordEncrypter() + rm := newDefaultRoleManager( + withRoleManagerDBConnectionPool(dbConnectionPool), + ) + + t.Run("return false when user isOwner but doesn't have the roles", func(t *testing.T) { + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, true, "role4") + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + hasRoles, err := rm.HasAnyRoles(ctx, u, []string{"role1", "role2", "role3"}) + require.NoError(t, err) + + assert.False(t, hasRoles) + }) + + t.Run("validates the user roles correctly", func(t *testing.T) { + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, false, "role1", "role2") + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + hasRoles, err := rm.HasAnyRoles(ctx, u, []string{"role1", "role2", "role3"}) + require.NoError(t, err) + assert.True(t, hasRoles) + + hasRoles, err = rm.HasAnyRoles(ctx, u, []string{"role3"}) + require.NoError(t, err) + assert.False(t, hasRoles) + + hasRoles, err = rm.HasAnyRoles(ctx, u, []string{"role1"}) + require.NoError(t, err) + assert.True(t, hasRoles) + + hasRoles, err = rm.HasAnyRoles(ctx, u, []string{"role2"}) + require.NoError(t, err) + assert.True(t, hasRoles) + + hasRoles, err = rm.HasAnyRoles(ctx, u, []string{"role1", "role2"}) + require.NoError(t, err) + assert.True(t, hasRoles) + + hasRoles, err = rm.HasAnyRoles(ctx, u, []string{"role1", "role3"}) + require.NoError(t, err) + assert.True(t, hasRoles) + }) +} + +func Test_DefaultRoleManager_IsSuperUser(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + pe := NewDefaultPasswordEncrypter() + rm := newDefaultRoleManager( + withRoleManagerDBConnectionPool(dbConnectionPool), + ) + + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, false) + rauOwner := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, true) + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + uo := &User{ + ID: rauOwner.ID, + Email: rauOwner.Email, + } + + isSuperUser, err := rm.IsSuperUser(ctx, u) + require.NoError(t, err) + assert.False(t, isSuperUser) + + isSuperUser, err = rm.IsSuperUser(ctx, uo) + require.NoError(t, err) + assert.True(t, isSuperUser) +} + +func Test_DefaultRoleManager_UpdateRoles(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + pe := NewDefaultPasswordEncrypter() + rm := newDefaultRoleManager( + withRoleManagerDBConnectionPool(dbConnectionPool), + ) + + rau := CreateRandomAuthUserFixture(t, ctx, dbConnectionPool, pe, false) + + u := &User{ + ID: rau.ID, + Email: rau.Email, + } + + err = rm.UpdateRoles(ctx, u, []string{"role1"}) + require.NoError(t, err) + + roles, err := rm.GetUserRoles(ctx, u) + require.NoError(t, err) + assert.Equal(t, []string{"role1"}, roles) + + err = rm.UpdateRoles(ctx, u, []string{"role1", "role2"}) + require.NoError(t, err) + + roles, err = rm.GetUserRoles(ctx, u) + require.NoError(t, err) + assert.Equal(t, []string{"role1", "role2"}, roles) + + err = rm.UpdateRoles(ctx, u, []string{"role3"}) + require.NoError(t, err) + + roles, err = rm.GetUserRoles(ctx, u) + require.NoError(t, err) + assert.Equal(t, []string{"role3"}, roles) + + err = rm.UpdateRoles(ctx, &User{ID: "user-id"}, []string{"role3"}) + assert.EqualError(t, err, ErrNoRowsAffected.Error()) +} + +func Test_withOwnerRoleName(t *testing.T) { + expectedRoleName := "my-owner-role-name" + rm := newDefaultRoleManager(withOwnerRoleName(expectedRoleName)) + assert.NotEqual(t, defaultOwnerRoleName, rm.ownerRoleName) + assert.Equal(t, expectedRoleName, rm.ownerRoleName) +} diff --git a/stellar-auth/pkg/cli/add_user.go b/stellar-auth/pkg/cli/add_user.go new file mode 100644 index 000000000..3effe8d09 --- /dev/null +++ b/stellar-auth/pkg/cli/add_user.go @@ -0,0 +1,231 @@ +package cli + +import ( + "context" + "fmt" + "go/types" + "regexp" + "strings" + + "github.com/manifoldco/promptui" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" +) + +type PasswordPromptInterface interface { + Run() (string, error) +} + +const ( + passwordMinLength = 8 + lowercasePattern = `[a-z]` + uppercasePattern = `[A-Z]` + digitsPattern = `[0-9]` + symbolsPattern = `[!@#$%^&*]` +) + +var ( + isOwner = false + passwordFlag = false +) + +func AddUserCmd(databaseURLFlagName string, passwordPrompt PasswordPromptInterface, availableRoles []string) *cobra.Command { + var rolesConfigKey []string + addUserCmdConfigOpts := config.ConfigOptions{ + { + Name: "owner", + Usage: `Set the user as Owner (superuser). Defaults to "false".`, + OptType: types.Bool, + ConfigKey: &isOwner, + FlagDefault: false, + Required: true, + }, + { + Name: "password", + Usage: "Sets the user password, it should be at least 8 characters long, if omitted, the command will generate a random one.", + OptType: types.Bool, + ConfigKey: &passwordFlag, + FlagDefault: false, + Required: false, + }, + } + + availableRolesDescription := "" + if len(availableRoles) > 0 { + availableRolesDescription = fmt.Sprintf("Available roles: [%s]", strings.Join(availableRoles, ", ")) + addUserCmdConfigOpts = append(addUserCmdConfigOpts, &config.ConfigOption{ + Name: "roles", + Usage: fmt.Sprintf("Set the user roles. It should be comma separated. Example: role1, role2. %s.", availableRolesDescription), + OptType: types.String, + CustomSetValue: setConfigOptionRoles, + ConfigKey: &rolesConfigKey, + Required: true, + }) + } + + addUser := &cobra.Command{ + Use: "add-user [--owner] [--roles] [--password]", + Short: "Add user to the system", + Long: "Add a user to the system. Email should be unique and password must be at least 8 characters long.", + Args: cobra.ExactArgs(3), + PersistentPreRun: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + + if cmd.Parent().PersistentPreRun != nil { + cmd.Parent().PersistentPreRun(cmd.Parent(), args) + // Sending this cmd to its parents' PersistentPreRun, so that it can prepare the dependencies for wrapping up this command, if needed. + cmd.Parent().PersistentPreRun(cmd, args) + } + + addUserCmdConfigOpts.Require() + err := addUserCmdConfigOpts.SetValues() + if err != nil { + log.Ctx(ctx).Fatalf("add-user error setting values of config options: %s", err.Error()) + } + + err = validateRoles(availableRoles, rolesConfigKey) + if err != nil { + log.Ctx(ctx).Fatalf("add-user error validating roles: %s", err.Error()) + } + }, + Run: func(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + + dbUrl := globalOptions.databaseURL + if dbUrl == "" { + dbUrl = viper.GetString(databaseURLFlagName) + } + + email, firstName, lastName := args[0], args[1], args[2] + + var password string + // If password flag is used, we prompt for a password. + // Otherwise a OTP password is generated by the Auth Manager. + if passwordFlag { + result, err := passwordPrompt.Run() + if err != nil { + log.Fatalf("add-user error prompting password: %s", err) + } + password = result + } + + err := execAddUser(ctx, dbUrl, email, firstName, lastName, password, isOwner, rolesConfigKey) + if err != nil { + log.Fatalf("add-user command error: %s", err) + } + log.Infof("user inserted: %s", args[0]) + }, + } + err := addUserCmdConfigOpts.Init(addUser) + if err != nil { + log.Fatalf("error initializing addUserCmd config option: %s", err.Error()) + } + + return addUser +} + +// NewDefaultPasswordPrompt returns the default password prompt used in add-user command. +func NewDefaultPasswordPrompt() *promptui.Prompt { + prompt := promptui.Prompt{ + Label: "Password", + Validate: PasswordPromptValidate, + Mask: ' ', + } + + return &prompt +} + +// PasswordPromptValidate validates the password input for add-user command. +func PasswordPromptValidate(input string) error { + if len(input) < passwordMinLength { + return fmt.Errorf("password must have more than %d characters", passwordMinLength) + } + + return validatePasswordCombination(input) +} + +// validatePasswordCombination returns an error if it does not consist of the four types of character requirements +func validatePasswordCombination(input string) error { + matchingPatterns := map[string]string{ + lowercasePattern: "lowercase letter", + uppercasePattern: "uppercase letter", + digitsPattern: "digit", + symbolsPattern: "symbol", + } + + const prefixErrStr = "password must contain at least one: " + errorStr := prefixErrStr + + for pattern, patternErr := range matchingPatterns { + matched, err := regexp.MatchString(pattern, input) + if err != nil { + return fmt.Errorf("error matching pattern %s", pattern) + } + if !matched { + errorStr += patternErr + ", " + } + } + + if errorStr != prefixErrStr { + return fmt.Errorf(strings.Trim(errorStr, ", ")) + } else { // even if password meets the above requirements, we still have to check for invalid characters + matchInvalidCharacters := fmt.Sprintf("^(.*%s.*%s.*%s.*%s.*)$", lowercasePattern, uppercasePattern, digitsPattern, symbolsPattern) + match, err := regexp.MatchString(matchInvalidCharacters, input) + if err != nil { + return fmt.Errorf("cannot match password to invalid characters regex: %w", err) + } + if !match { + return fmt.Errorf("password contains invalid characters") + } + } + + return nil +} + +// execAddUser creates a new user and inserts it into the database, the user will have +// it's password encrypted for security reasons. +func execAddUser(ctx context.Context, dbUrl string, email, firstName, lastName, password string, isOwner bool, roles []string) error { + dbConnectionPool, err := db.OpenDBConnectionPool(dbUrl) + if err != nil { + return fmt.Errorf("error getting dbConnectionPool in execAddUser: %w", err) + } + defer dbConnectionPool.Close() + + authManager := auth.NewAuthManager( + auth.WithDefaultAuthenticatorOption(dbConnectionPool, auth.NewDefaultPasswordEncrypter(), 0), + ) + + newUser := &auth.User{ + FirstName: firstName, + LastName: lastName, + Email: email, + IsOwner: isOwner, + Roles: roles, + } + + _, err = authManager.CreateUser(ctx, newUser, password) + if err != nil { + return fmt.Errorf("error creating user: %w", err) + } + + return nil +} + +func validateRoles(availableRoles []string, rolesConfigKey []string) error { + availableRolesMap := make(map[string]struct{}, len(availableRoles)) + for _, role := range availableRoles { + availableRolesMap[role] = struct{}{} + } + + for _, role := range rolesConfigKey { + if _, ok := availableRolesMap[role]; !ok { + return fmt.Errorf("invalid role provided. Expected one of these values: %s", strings.Join(availableRoles, " | ")) + } + } + + return nil +} diff --git a/stellar-auth/pkg/cli/add_user_test.go b/stellar-auth/pkg/cli/add_user_test.go new file mode 100644 index 000000000..f25c41b08 --- /dev/null +++ b/stellar-auth/pkg/cli/add_user_test.go @@ -0,0 +1,337 @@ +package cli + +import ( + "context" + "strings" + "testing" + + "github.com/lib/pq" + "github.com/spf13/cobra" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/pkg/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type PasswordPromptMock struct{} + +func (m *PasswordPromptMock) Run() (string, error) { + return "mockpassword", nil +} + +func Test_authAddUserCommand(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + mockPrompt := PasswordPromptMock{} + mockedPassword, _ := mockPrompt.Run() + + t.Run("Should create a new user", func(t *testing.T) { + addUser := AddUserCmd("database-url", &mockPrompt, []string{}) + rootCmd := rootCmd() + rootCmd.AddCommand(addUser) + + newEmail := "newuser@email.com" + firstName := "first" + lastName := "last" + rootCmd.SetArgs([]string{"--database-url", dbt.DSN, "add-user", newEmail, firstName, lastName, "--password"}) + err := rootCmd.Execute() + require.NoError(t, err) + + var dbEmail, dbPassword, dbFirstName, dbLastName string + var dbIsOwner bool + err = dbConnectionPool.QueryRowxContext(ctx, "SELECT email, encrypted_password, is_owner, first_name, last_name FROM auth_users WHERE email = $1", newEmail).Scan(&dbEmail, &dbPassword, &dbIsOwner, &dbFirstName, &dbLastName) + require.NoError(t, err) + + assert.Equal(t, newEmail, dbEmail) + assert.NotEqual(t, dbPassword, mockedPassword) + assert.False(t, dbIsOwner) + assert.Equal(t, firstName, dbFirstName) + assert.Equal(t, lastName, dbLastName) + }) + + t.Run("Should create a new Owner user", func(t *testing.T) { + addUser := AddUserCmd("database-url", &mockPrompt, []string{}) + rootCmd := rootCmd() + rootCmd.AddCommand(addUser) + + newEmail := "newuserowner@email.com" + firstName := "first" + lastName := "last" + rootCmd.SetArgs([]string{"--database-url", dbt.DSN, "add-user", newEmail, firstName, lastName, "--password", "--owner"}) + err := rootCmd.Execute() + require.NoError(t, err) + + var dbEmail, dbPassword, dbFirstName, dbLastName string + var dbIsOwner bool + err = dbConnectionPool.QueryRowxContext(ctx, "SELECT email, encrypted_password, is_owner, first_name, last_name FROM auth_users WHERE email = $1", newEmail).Scan(&dbEmail, &dbPassword, &dbIsOwner, &dbFirstName, &dbLastName) + require.NoError(t, err) + + assert.Equal(t, newEmail, dbEmail) + assert.NotEqual(t, dbPassword, mockedPassword) + assert.True(t, dbIsOwner) + assert.Equal(t, firstName, dbFirstName) + assert.Equal(t, lastName, dbLastName) + }) + + t.Run("Should create a new user with random generated password", func(t *testing.T) { + addUser := AddUserCmd("database-url", &mockPrompt, []string{}) + rootCmd := rootCmd() + rootCmd.AddCommand(addUser) + + newEmail := "newuserpass@email.com" + firstName := "first" + lastName := "last" + rootCmd.SetArgs([]string{"--database-url", dbt.DSN, "add-user", newEmail, firstName, lastName}) + err := rootCmd.Execute() + require.NoError(t, err) + + var dbEmail, dbPassword, dbFirstName, dbLastName string + var dbIsOwner bool + err = dbConnectionPool.QueryRowxContext(ctx, "SELECT email, encrypted_password, is_owner, first_name, last_name FROM auth_users WHERE email = $1", newEmail).Scan(&dbEmail, &dbPassword, &dbIsOwner, &dbFirstName, &dbLastName) + require.NoError(t, err) + + assert.Equal(t, newEmail, dbEmail) + assert.NotEmpty(t, dbPassword) + assert.False(t, isOwner) + }) + + t.Run("should show the correct usage", func(t *testing.T) { + setTestCmd := func() *cobra.Command { + return &cobra.Command{ + Use: "test", + } + } + + addUserCmd := AddUserCmd("database-url", &mockPrompt, []string{}) + + buf := new(strings.Builder) + testCmd := setTestCmd() + testCmd.SetOut(buf) + testCmd.AddCommand(addUserCmd) + + testCmd.SetArgs([]string{"add-user", "--help"}) + err := testCmd.Execute() + require.NoError(t, err) + + expectedUsage := `Add a user to the system. Email should be unique and password must be at least 8 characters long. + +Usage: + test add-user [--owner] [--roles] [--password] [flags] + +Flags: + -h, --help help for add-user + --owner Set the user as Owner (superuser). Defaults to "false". (OWNER) + --password Sets the user password, it should be at least 8 characters long, if omitted, the command will generate a random one. (PASSWORD) +` + assert.Equal(t, expectedUsage, buf.String()) + + addUserCmd = AddUserCmd("database-url", &mockPrompt, []string{"role1", "role2", "role3", "role4"}) + + buf = new(strings.Builder) + testCmd = setTestCmd() + testCmd.SetOut(buf) + testCmd.AddCommand(addUserCmd) + + testCmd.SetArgs([]string{"add-user", "--help"}) + err = testCmd.Execute() + require.NoError(t, err) + + expectedUsage = `Add a user to the system. Email should be unique and password must be at least 8 characters long. + +Usage: + test add-user [--owner] [--roles] [--password] [flags] + +Flags: + -h, --help help for add-user + --owner Set the user as Owner (superuser). Defaults to "false". (OWNER) + --password Sets the user password, it should be at least 8 characters long, if omitted, the command will generate a random one. (PASSWORD) + --roles string Set the user roles. It should be comma separated. Example: role1, role2. Available roles: [role1, role2, role3, role4]. (ROLES) +` + assert.Equal(t, expectedUsage, buf.String()) + }) + + t.Run("set the user roles", func(t *testing.T) { + rootCmd := rootCmd() + addUserCmd := AddUserCmd("database-url", &mockPrompt, []string{"role1", "role2"}) + rootCmd.AddCommand(addUserCmd) + + buf := new(strings.Builder) + rootCmd.SetOut(buf) + + email, firstName, lastName := "test@email.com", "First", "Last" + + rootCmd.SetArgs([]string{"--database-url", dbt.DSN, "add-user", email, firstName, lastName, "--roles", "role2"}) + err := rootCmd.Execute() + require.NoError(t, err) + + var dbUsername, dbFirstName, dbLastName string + var dbRoles []string + err = dbConnectionPool.QueryRowxContext(ctx, "SELECT email, first_name, last_name, roles FROM auth_users WHERE email = $1", email).Scan(&dbUsername, &dbFirstName, &dbLastName, pq.Array(&dbRoles)) + require.NoError(t, err) + + assert.Equal(t, email, dbUsername) + assert.Equal(t, firstName, dbFirstName) + assert.Equal(t, lastName, dbLastName) + assert.Equal(t, []string{"role2"}, dbRoles) + }) +} + +func Test_execAddUserFunc(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + + t.Run("User must be valid", func(t *testing.T) { + email, password, firstName, lastName := "test@email.com", "mypassword", "First", "Last" + + // Invalid invalid + err := execAddUser(ctx, dbt.DSN, "", firstName, lastName, password, false, []string{}) + assert.EqualError(t, err, "error creating user: error creating user: error validating user fields: email is required") + + err = execAddUser(ctx, dbt.DSN, "wrongemail", firstName, lastName, password, false, []string{}) + assert.EqualError(t, err, `error creating user: error creating user: error validating user fields: email is invalid: the provided email "wrongemail" is not valid`) + + // Invalid password + err = execAddUser(ctx, dbt.DSN, email, firstName, lastName, "pass", false, []string{}) + assert.EqualError(t, err, "error creating user: error creating user: error encrypting password: password should have at least 8 characters") + + // Invalid first name + err = execAddUser(ctx, dbt.DSN, email, "", lastName, "pass", false, []string{}) + assert.EqualError(t, err, "error creating user: error creating user: error validating user fields: first name is required") + + // Invalid last name + err = execAddUser(ctx, dbt.DSN, email, firstName, "", "pass", false, []string{}) + assert.EqualError(t, err, "error creating user: error creating user: error validating user fields: last name is required") + + // Valid user + err = execAddUser(ctx, dbt.DSN, email, firstName, lastName, password, false, []string{}) + require.NoError(t, err) + }) + + t.Run("Inserted user must have his password encrypted", func(t *testing.T) { + email, password, firstName, lastName := "test2@email.com", "mypassword", "First", "Last" + + err := execAddUser(ctx, dbt.DSN, email, firstName, lastName, password, false, []string{}) + require.NoError(t, err) + + var dbPassword string + err = dbConnectionPool.QueryRowxContext(ctx, "SELECT encrypted_password FROM auth_users WHERE email = $1", email).Scan(&dbPassword) + require.NoError(t, err) + assert.NotEqual(t, password, dbPassword) + + encrypter := auth.NewDefaultPasswordEncrypter() + + compare, err := encrypter.ComparePassword(ctx, dbPassword, password) + require.NoError(t, err) + assert.True(t, compare) + }) + + t.Run("Email should be unique", func(t *testing.T) { + email, password, firstName, lastName := "unique@email.com", "mypassword", "First", "Last" + + err := execAddUser(ctx, dbt.DSN, email, firstName, lastName, password, false, []string{}) + require.NoError(t, err) + + err = execAddUser(ctx, dbt.DSN, email, firstName, lastName, password, false, []string{}) + assert.EqualError(t, err, `error creating user: error creating user: a user with this email already exists`) + }) + + t.Run("set the user roles", func(t *testing.T) { + email, password, firstName, lastName := "testroles@email.com", "mypassword", "First", "Last" + + err := execAddUser(ctx, dbt.DSN, email, firstName, lastName, password, false, []string{"role1", "role2"}) + require.NoError(t, err) + + var dbRoles []string + err = dbConnectionPool.QueryRowxContext(ctx, "SELECT roles FROM auth_users WHERE email = $1", email).Scan(pq.Array(dbRoles)) + require.NoError(t, err) + assert.NotEqual(t, []string{"role1", "role2"}, dbRoles) + }) +} + +func Test_PasswordPromptValidateFunc(t *testing.T) { + testCases := []struct { + name string + input string + errContains []string + }{ + { + name: "returns an error if the input is less than 8 characters", + input: "test", + errContains: []string{"password must have more than 8 characters"}, + }, + { + name: "returns an error if the input does not contain all the required characters (NO uppercase letters, symbols)", + input: "test1234", + errContains: []string{"uppercase letter", "symbol"}, + }, + { + name: "return an error if the input does not contain all the required characters (NO digits)", + input: "test#ABC", + errContains: []string{"digit"}, + }, + { + name: "returns an error if the input does not contain all the required characters (NO digits, symbols)", + input: "testTEST", + errContains: []string{"digit", "symbol"}, + }, + { + name: "returns an error if the input does not contain all the required characters (NO lowercase letters, symbols)", + input: "TEST123123", + errContains: []string{"lowercase letter", "symbol"}, + }, + { + name: "returns an error if the input does not contain all the required characters (NO lowercase, uppercase letters, symbols)", + input: "1010011010", + errContains: []string{"lowercase letter", "uppercase letter", "symbol"}, + }, + { + name: "returns an error if the input contains invalid character(s) but fulfills the minimum character requirement", + input: "1Tv(^_^)vT1", + errContains: []string{"password contains invalid characters"}, + }, + { + name: "returns no error if the input is valid (happy path 1)", + input: "tEsT123#@", + }, + { + name: "returns no error if the input is valid (happy path 2)", + input: "h3LL0w0rLd$$$", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := PasswordPromptValidate(tc.input) + if tc.errContains == nil { + require.NoError(t, err) + } else { + for _, ec := range tc.errContains { + require.ErrorContains(t, err, ec) + } + } + }) + } +} + +func Test_validateRoles(t *testing.T) { + err := validateRoles([]string{"role1", "role2"}, []string{"role2", "role3"}) + assert.EqualError(t, err, "invalid role provided. Expected one of these values: role1 | role2") + + err = validateRoles([]string{"role1", "role2"}, []string{"role2", "role1"}) + assert.Nil(t, err) + + err = validateRoles([]string{}, []string{}) + assert.Nil(t, err) +} diff --git a/stellar-auth/pkg/cli/custom_set_value.go b/stellar-auth/pkg/cli/custom_set_value.go new file mode 100644 index 000000000..18796afc8 --- /dev/null +++ b/stellar-auth/pkg/cli/custom_set_value.go @@ -0,0 +1,56 @@ +package cli + +import ( + "fmt" + "strings" + + "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" +) + +func SetConfigOptionLogLevel(co *config.ConfigOption) error { + // parse string to logLevel object + logLevelStr := viper.GetString(co.Name) + logLevel, err := logrus.ParseLevel(logLevelStr) + if err != nil { + return fmt.Errorf("couldn't parse log level: %w", err) + } + + // update the configKey + key, ok := co.ConfigKey.(*logrus.Level) + if !ok { + return fmt.Errorf("configKey has an invalid type %T", co.ConfigKey) + } + *key = logLevel + + // Log for debugging + if config.IsExplicitlySet(co) { + log.Debugf("Setting log level to: %q", logLevel) + log.DefaultLogger.SetLevel(*key) + } else { + log.Debugf("Using default log level: %q", logLevel) + } + return nil +} + +func setConfigOptionRoles(co *config.ConfigOption) error { + rolesStr := viper.GetString(co.Name) + rolesSplit := strings.FieldsFunc(rolesStr, func(r rune) bool { + return r == ',' + }) + + roles := make([]string, 0, len(rolesSplit)) + for _, role := range rolesSplit { + roles = append(roles, strings.TrimSpace(role)) + } + + key, ok := co.ConfigKey.(*[]string) + if !ok { + return fmt.Errorf("the expected type for this config key is a string slice, but got a %T instead", co.ConfigKey) + } + *key = roles + + return nil +} diff --git a/stellar-auth/pkg/cli/custom_set_value_test.go b/stellar-auth/pkg/cli/custom_set_value_test.go new file mode 100644 index 000000000..499193f18 --- /dev/null +++ b/stellar-auth/pkg/cli/custom_set_value_test.go @@ -0,0 +1,160 @@ +package cli + +import ( + "go/types" + "testing" + + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/stellar/go/support/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_SetConfigOptionLogLevel(t *testing.T) { + co := config.ConfigOption{ + Name: "log-level", + OptType: types.String, + CustomSetValue: SetConfigOptionLogLevel, + } + + executeCmd := func(args []string, handleError func(err error)) { + // mock a command line argument + testCmd := cobra.Command{ + Run: func(cmd *cobra.Command, args []string) { + co.Require() + // forward error to the error handler callback: + handleError(co.SetValue()) + }, + } + err := co.Init(&testCmd) + require.NoError(t, err) + + // execute command line + testCmd.SetArgs(args) + err = testCmd.Execute() + require.NoError(t, err) + } + + // invalid log level should return an error + testCount := 0 + executeCmd([]string{"--log-level", "aaa"}, func(err error) { + require.EqualError(t, err, `couldn't parse log level: not a valid logrus Level: "aaa"`) + testCount++ + }) + require.Equal(t, 1, testCount) + + // misconfigured configKey should return an error + executeCmd([]string{"--log-level", "info"}, func(err error) { + require.EqualError(t, err, `configKey has an invalid type `) + testCount++ + }) + require.Equal(t, 2, testCount) + + // valid log level should set the configKey + var logrusLevel logrus.Level + require.NotEqual(t, logrus.InfoLevel, logrusLevel) + co.ConfigKey = &logrusLevel + executeCmd([]string{"--log-level", "info"}, func(err error) { + require.NoError(t, err) + testCount++ + }) + require.Equal(t, 3, testCount) + require.Equal(t, logrus.InfoLevel, logrusLevel) + + // If no value is passed, stick with the default ("TRACE") + co.FlagDefault = "TRACE" + require.NotEqual(t, logrus.TraceLevel, logrusLevel) + executeCmd(nil, func(err error) { + require.NoError(t, err) + testCount++ + }) + require.Equal(t, 4, testCount) + require.Equal(t, logrus.TraceLevel, logrusLevel) +} + +func Test_setConfigOptionRoles(t *testing.T) { + var rolesConfigKey []string + + co := config.ConfigOption{ + Name: "roles", + OptType: types.String, + CustomSetValue: setConfigOptionRoles, + ConfigKey: &rolesConfigKey, + } + + executeCmd := func(args []string, handleError func(err error)) { + // mock a command line argument + testCmd := cobra.Command{ + Run: func(cmd *cobra.Command, args []string) { + co.Require() + // forward error to the error handler callback: + handleError(co.SetValue()) + }, + } + err := co.Init(&testCmd) + require.NoError(t, err) + + // execute command line + testCmd.SetArgs(args) + err = testCmd.Execute() + require.NoError(t, err) + } + + t.Run("handles set the roles through the CLI flag", func(t *testing.T) { + testCount := 0 + executeCmd([]string{"--roles", "role1, role2, role3"}, func(err error) { + require.NoError(t, err) + testCount++ + }) + + assert.Equal(t, []string{"role1", "role2", "role3"}, rolesConfigKey) + + executeCmd([]string{"--roles", "role1,role2,role3"}, func(err error) { + require.NoError(t, err) + testCount++ + }) + + assert.Equal(t, []string{"role1", "role2", "role3"}, rolesConfigKey) + + executeCmd([]string{"--roles", ""}, func(err error) { + require.NoError(t, err) + testCount++ + }) + + assert.Equal(t, []string{}, rolesConfigKey) + assert.Equal(t, 3, testCount) + }) + + t.Run("handles set the roles through Env Vars", func(t *testing.T) { + testCount := 0 + + t.Setenv("ROLES", "role1, role2, role3") + + executeCmd([]string{}, func(err error) { + require.NoError(t, err) + testCount++ + }) + + assert.Equal(t, []string{"role1", "role2", "role3"}, rolesConfigKey) + + t.Setenv("ROLES", "role1,role2,role3") + + executeCmd([]string{}, func(err error) { + require.NoError(t, err) + testCount++ + }) + + assert.Equal(t, []string{"role1", "role2", "role3"}, rolesConfigKey) + + t.Setenv("ROLES", "") + + executeCmd([]string{"--roles", ""}, func(err error) { + require.NoError(t, err) + testCount++ + }) + + assert.Equal(t, []string{}, rolesConfigKey) + assert.Equal(t, 3, testCount) + }) +} diff --git a/stellar-auth/pkg/cli/migrate.go b/stellar-auth/pkg/cli/migrate.go new file mode 100644 index 000000000..280b836ed --- /dev/null +++ b/stellar-auth/pkg/cli/migrate.go @@ -0,0 +1,91 @@ +package cli + +import ( + "fmt" + "strconv" + + migrate "github.com/rubenv/sql-migrate" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/stellar/go/support/log" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" +) + +func MigrateCmd(databaseFlagName string) *cobra.Command { + migrateCmd := &cobra.Command{ + Use: "migrate", + Short: "Apply Stellar Auth database migrations", + Run: func(cmd *cobra.Command, args []string) { + if err := cmd.Help(); err != nil { + log.Fatalf("Error calling help command: %s", err.Error()) + } + }, + } + + migrateUp := &cobra.Command{ + Use: "up [count]", + Short: "Migrates database up [count]", + Args: cobra.MaximumNArgs(1), + Run: func(cmd *cobra.Command, args []string) { + var count int + if len(args) > 0 { + var err error + count, err = strconv.Atoi(args[0]) + if err != nil { + log.Fatalf("Invalid [count] argument: %s", args[0]) + } + } + + dbURL := globalOptions.databaseURL + if globalOptions.databaseURL == "" { + dbURL = viper.GetString(databaseFlagName) + } + + err := runMigration(dbURL, migrate.Up, count) + if err != nil { + log.Fatalf("Error migrating database Up: %s", err.Error()) + } + }, + } + migrateCmd.AddCommand(migrateUp) + + migrateDown := &cobra.Command{ + Use: "down [count]", + Short: "Migrates database down [count] migrations", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + count, err := strconv.Atoi(args[0]) + if err != nil { + log.Fatalf("Invalid [count] argument: %s", args[0]) + } + + dbURL := globalOptions.databaseURL + if globalOptions.databaseURL == "" { + dbURL = viper.GetString(databaseFlagName) + } + + err = runMigration(dbURL, migrate.Down, count) + if err != nil { + log.Fatalf("Error migrating database Down: %s", err.Error()) + } + }, + } + migrateCmd.AddCommand(migrateDown) + + return migrateCmd +} + +func runMigration(databaseURL string, dir migrate.MigrationDirection, count int) error { + numMigrationsRun, err := db.Migrate(databaseURL, dir, count) + if err != nil { + return fmt.Errorf("running migrations: %w", err) + } + + if numMigrationsRun == 0 { + log.Info("No migrations applied.") + } else { + log.Infof("Successfully applied %d migrations.", numMigrationsRun) + } + + return nil +} diff --git a/stellar-auth/pkg/cli/migrate_test.go b/stellar-auth/pkg/cli/migrate_test.go new file mode 100644 index 000000000..5333ba139 --- /dev/null +++ b/stellar-auth/pkg/cli/migrate_test.go @@ -0,0 +1,217 @@ +package cli + +import ( + "context" + "database/sql" + "fmt" + "strings" + "testing" + + migrate "github.com/rubenv/sql-migrate" + "github.com/spf13/cobra" + "github.com/spf13/viper" + stellardbtest "github.com/stellar/go/support/db/dbtest" + "github.com/stellar/go/support/log" + dbpkg "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-auth/internal/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getMigrationsApplied(t *testing.T, ctx context.Context, db *sql.DB) []string { + rows, err := db.QueryContext(ctx, fmt.Sprintf("SELECT id FROM %s", dbpkg.StellarAuthMigrationsTableName)) + require.NoError(t, err) + + defer rows.Close() + + ids := []string{} + for rows.Next() { + var id string + err := rows.Scan(&id) + require.NoError(t, err) + + ids = append(ids, id) + } + + require.NoError(t, rows.Err()) + + return ids +} + +func Test_MigrateCmd(t *testing.T) { + testCases := []struct { + name string + args []string + envVars map[string]string + expect string + expectError string + preRunFunc func(*testing.T, *stellardbtest.DB) + postRunFunc func(*sql.DB) + }{ + { + name: "test help command", + args: []string{"migrate", "--help"}, + expect: "Apply Stellar Auth database migrations\n\nUsage:\n stellarauth migrate [flags]\n stellarauth migrate [command]\n\nAvailable Commands:\n down Migrates database down [count] migrations\n up Migrates database up [count]\n\nFlags:\n -h, --help help for migrate\n\nGlobal Flags:\n --database-url string Postgres DB URL (DATABASE_URL) (default \"postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable\")\n --log-level string The log level used in this project. Options: \"TRACE\", \"DEBUG\", \"INFO\", \"WARN\", \"ERROR\", \"FATAL\", or \"PANIC\". (LOG_LEVEL) (default \"TRACE\")\n\nUse \"stellarauth migrate [command] --help\" for more information about a command.\n", + }, + { + name: "test short help command", + args: []string{"migrate", "-h"}, + expect: "Apply Stellar Auth database migrations\n\nUsage:\n stellarauth migrate [flags]\n stellarauth migrate [command]\n\nAvailable Commands:\n down Migrates database down [count] migrations\n up Migrates database up [count]\n\nFlags:\n -h, --help help for migrate\n\nGlobal Flags:\n --database-url string Postgres DB URL (DATABASE_URL) (default \"postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable\")\n --log-level string The log level used in this project. Options: \"TRACE\", \"DEBUG\", \"INFO\", \"WARN\", \"ERROR\", \"FATAL\", or \"PANIC\". (LOG_LEVEL) (default \"TRACE\")\n\nUse \"stellarauth migrate [command] --help\" for more information about a command.\n", + }, + { + name: "test migrate up successfully", + args: []string{"--log-level", "TRACE", "--database-url", "", "migrate", "up", "1"}, + expect: "Successfully applied 1 migrations.", + postRunFunc: func(db *sql.DB) { + ids := getMigrationsApplied(t, context.Background(), db) + assert.Equal(t, []string{"2023-02-09.0.add-users-table.sql"}, ids) + }, + }, + { + name: "test migrate up successfully when using the DATABASE_URL env var", + args: []string{"--log-level", "TRACE", "migrate", "up", "1"}, + envVars: map[string]string{"DATABASE_URL": ""}, + expect: "Successfully applied 1 migrations.", + postRunFunc: func(db *sql.DB) { + ids := getMigrationsApplied(t, context.Background(), db) + assert.Equal(t, []string{"2023-02-09.0.add-users-table.sql"}, ids) + }, + }, + { + name: "test apply migrations when no number of migration is specified", + args: []string{"--log-level", "TRACE", "--database-url", "", "migrate", "up"}, + expect: "Successfully applied", + expectError: "", + }, + { + name: "test migrate down usage", + args: []string{"migrate", "down"}, + expect: "Usage:\n stellarauth migrate down [count] [flags]\n\nFlags:\n -h, --help help for down\n\nGlobal Flags:\n --database-url string Postgres DB URL (DATABASE_URL) (default \"postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable\")\n --log-level string The log level used in this project. Options: \"TRACE\", \"DEBUG\", \"INFO\", \"WARN\", \"ERROR\", \"FATAL\", or \"PANIC\". (LOG_LEVEL) (default \"TRACE\")\n\n", + expectError: "accepts 1 arg(s), received 0", + }, + { + name: "test migrate up successfully", + args: []string{"--log-level", "TRACE", "--database-url", "", "migrate", "down", "1"}, + expect: "Successfully applied 1 migrations.", + preRunFunc: func(t *testing.T, db *stellardbtest.DB) { + _, err := dbpkg.Migrate(db.DSN, migrate.Up, 1) + require.NoError(t, err) + + conn := db.Open() + defer conn.Close() + + ids := getMigrationsApplied(t, context.Background(), conn.DB) + assert.Equal(t, []string{"2023-02-09.0.add-users-table.sql"}, ids) + }, + postRunFunc: func(db *sql.DB) { + ids := getMigrationsApplied(t, context.Background(), db) + assert.Equal(t, []string{}, ids) + }, + }, + { + name: "test migrate up successfully when using the DATABASE_URL env var", + args: []string{"--log-level", "TRACE", "migrate", "down", "1"}, + envVars: map[string]string{"DATABASE_URL": ""}, + expect: "Successfully applied 1 migrations.", + preRunFunc: func(t *testing.T, db *stellardbtest.DB) { + _, err := dbpkg.Migrate(db.DSN, migrate.Up, 1) + require.NoError(t, err) + + conn := db.Open() + defer conn.Close() + + ids := getMigrationsApplied(t, context.Background(), conn.DB) + assert.Equal(t, []string{"2023-02-09.0.add-users-table.sql"}, ids) + }, + postRunFunc: func(db *sql.DB) { + ids := getMigrationsApplied(t, context.Background(), db) + assert.Equal(t, []string{}, ids) + }, + }, + } + + for _, tc := range testCases { + db := dbtest.OpenWithoutMigrations(t) + + if len(tc.args) >= 3 && tc.args[2] == "--database-url" { + tc.args[3] = db.DSN + } + + t.Run(tc.name, func(t *testing.T) { + if tc.preRunFunc != nil { + tc.preRunFunc(t, db) + } + + for key, value := range tc.envVars { + if key == "DATABASE_URL" { + value = db.DSN + } + t.Setenv(key, value) + } + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + + rootCmd := rootCmd() + rootCmd.SetOut(buf) + rootCmd.AddCommand(MigrateCmd("")) + rootCmd.SetArgs(tc.args) + + err := rootCmd.Execute() + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + require.NoError(t, err) + } + + output := buf.String() + if tc.expect != "" { + assert.Contains(t, output, tc.expect) + } + + if tc.postRunFunc != nil { + conn := db.Open() + tc.postRunFunc(conn.DB) + conn.Close() + } + }) + + db.Close() + } +} + +func Test_MigrateCmd_databaseFlagName(t *testing.T) { + globalOptions = globalOptionsType{} + + dbt := dbtest.OpenWithoutMigrations(t) + defer dbt.Close() + + testCmd := &cobra.Command{ + Use: "testcmd", + Run: func(cmd *cobra.Command, args []string) { + err := cmd.Help() + require.NoError(t, err) + }, + } + + testCmd.PersistentFlags().String("db-url", dbt.DSN, "") + + err := viper.BindPFlag("db-url", testCmd.PersistentFlags().Lookup("db-url")) + require.NoError(t, err) + + err = viper.BindEnv("DB_URL", dbt.DSN) + require.NoError(t, err) + + testCmd.AddCommand(MigrateCmd("db-url")) + testCmd.SetArgs([]string{"migrate", "up", "1"}) + + buf := new(strings.Builder) + log.DefaultLogger.SetOutput(buf) + log.DefaultLogger.SetLevel(log.InfoLevel) + testCmd.SetOut(buf) + + err = testCmd.Execute() + require.NoError(t, err) + + assert.Contains(t, buf.String(), "Successfully applied 1 migrations.") +} diff --git a/stellar-auth/pkg/cli/root.go b/stellar-auth/pkg/cli/root.go new file mode 100644 index 000000000..a6bc1df76 --- /dev/null +++ b/stellar-auth/pkg/cli/root.go @@ -0,0 +1,79 @@ +package cli + +import ( + "go/types" + + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/stellar/go/support/config" + "github.com/stellar/go/support/log" +) + +type globalOptionsType struct { + version string + gitCommit string + databaseURL string + logLevel logrus.Level +} + +var globalOptions globalOptionsType + +func rootCmd() *cobra.Command { + configOptions := config.ConfigOptions{ + { + Name: "log-level", + Usage: `The log level used in this project. Options: "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", or "PANIC".`, + OptType: types.String, + FlagDefault: "TRACE", + ConfigKey: &globalOptions.logLevel, + CustomSetValue: SetConfigOptionLogLevel, + Required: true, + }, + { + Name: "database-url", + Usage: "Postgres DB URL", + OptType: types.String, + FlagDefault: "postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable", + ConfigKey: &globalOptions.databaseURL, + Required: true, + }, + } + + cmd := &cobra.Command{ + Use: "stellarauth", + Short: "Stellar Auth handles JWT management.", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + configOptions.Require() + err := configOptions.SetValues() + if err != nil { + log.Fatalf("Error setting values of config options: %s", err.Error()) + } + + log.Info("Version: ", globalOptions.version) + log.Info("GitCommit: ", globalOptions.gitCommit) + }, + Run: func(cmd *cobra.Command, args []string) { + if err := cmd.Help(); err != nil { + log.Fatalf("Error calling help command: %s", err.Error()) + } + }, + } + + if err := configOptions.Init(cmd); err != nil { + log.Fatalf("Error initializing a config option: %s", err.Error()) + } + + return cmd +} + +func SetupCLI(version, gitCommit string) *cobra.Command { + globalOptions.version = version + globalOptions.gitCommit = gitCommit + + cmd := rootCmd() + + cmd.AddCommand(MigrateCmd("")) + cmd.AddCommand(AddUserCmd("", NewDefaultPasswordPrompt(), []string{})) + + return cmd +} diff --git a/stellar-auth/pkg/cli/root_test.go b/stellar-auth/pkg/cli/root_test.go new file mode 100644 index 000000000..e1db43a94 --- /dev/null +++ b/stellar-auth/pkg/cli/root_test.go @@ -0,0 +1,86 @@ +package cli + +import ( + "strings" + "testing" + + "github.com/stellar/go/support/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_rootCmd(t *testing.T) { + testCases := []struct { + name string + args []string + envVars map[string]string + expect string + notExpect string + }{ + { + name: "test help command", + args: []string{"--help"}, + expect: "Stellar Auth handles JWT management.\n\nUsage:\n stellarauth [flags]\n\nFlags:\n --database-url string Postgres DB URL (DATABASE_URL) (default \"postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable\")\n -h, --help help for stellarauth\n --log-level string The log level used in this project. Options: \"TRACE\", \"DEBUG\", \"INFO\", \"WARN\", \"ERROR\", \"FATAL\", or \"PANIC\". (LOG_LEVEL) (default \"TRACE\")\n", + }, + { + name: "test short help command", + args: []string{"-h"}, + expect: "Stellar Auth handles JWT management.\n\nUsage:\n stellarauth [flags]\n\nFlags:\n --database-url string Postgres DB URL (DATABASE_URL) (default \"postgres://postgres:postgres@localhost:5432/stellar-auth?sslmode=disable\")\n -h, --help help for stellarauth\n --log-level string The log level used in this project. Options: \"TRACE\", \"DEBUG\", \"INFO\", \"WARN\", \"ERROR\", \"FATAL\", or \"PANIC\". (LOG_LEVEL) (default \"TRACE\")\n", + }, + { + name: "test set log-level", + args: []string{"--log-level", "INFO"}, + expect: "msg=\"GitCommit: \"", + }, + { + name: "test set log-level with WARN level and doesn't logs INFO messages", + args: []string{"--log-level", "WARN"}, + expect: "", + notExpect: "msg=\"GitCommit: \"", + }, + { + name: "test set database-url", + args: []string{"--log-level", "WARN", "--database-url", "postgres://localhost@5432/stellar-auth?sslmode=disable"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for key, value := range tc.envVars { + t.Setenv(key, value) + } + + rootCmd := rootCmd() + rootCmd.SetArgs(tc.args) + + buf := new(strings.Builder) + + log.DefaultLogger.SetOutput(buf) + rootCmd.SetOut(buf) + + err := rootCmd.Execute() + require.NoError(t, err) + + output := buf.String() + if tc.expect != "" { + assert.Contains(t, output, tc.expect) + } + + if tc.notExpect != "" { + assert.NotContains(t, output, tc.notExpect) + } + }) + } +} + +func Test_SetupCLI(t *testing.T) { + cmd := SetupCLI("v0.0.1", "a1b2c3d4") + + buf := new(strings.Builder) + cmd.SetOut(buf) + + err := cmd.Execute() + require.NoError(t, err) + + assert.Contains(t, buf.String(), "migrate Apply Stellar Auth database migrations") +} diff --git a/stellar-auth/pkg/utils/utils.go b/stellar-auth/pkg/utils/utils.go new file mode 100644 index 000000000..f2e6b9180 --- /dev/null +++ b/stellar-auth/pkg/utils/utils.go @@ -0,0 +1,52 @@ +package utils + +import ( + "crypto/rand" + "fmt" + "math/big" + "regexp" +) + +const ( + // Default charset to be used with StringWithCharset function + DefaultCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + SpecialCharset = "!@#$%&*+-_" + // Password charset adds special chars + PasswordCharset = DefaultCharset + SpecialCharset +) + +// Generates a random string with the charset infromed and the length +func StringWithCharset(length int, charset string) (string, error) { + b := make([]byte, length) + for i := range b { + randomNumber, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + if err != nil { + return "", fmt.Errorf("error generating random number in StringWithCharset: %w", err) + } + b[i] = charset[randomNumber.Int64()] + } + return string(b), nil +} + +// RxEmail is a regex used to validate e-mail addresses, according with the reference https://www.alexedwards.net/blog/validation-snippets-for-go#email-validation. +// It's free to use under the [MIT License](https://opensource.org/licenses/MIT) +var rxEmail = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") + +func ValidateEmail(email string) error { + if email == "" { + return fmt.Errorf("email cannot be empty") + } + + if !rxEmail.MatchString(email) { + return fmt.Errorf("the provided email %q is not valid", email) + } + + return nil +} + +func TruncateString(str string, borderSizeToKeep int) string { + if len(str) <= 2*borderSizeToKeep { + return str + } + return str[:borderSizeToKeep] + "..." + str[len(str)-borderSizeToKeep:] +} diff --git a/stellar-auth/pkg/utils/utils_test.go b/stellar-auth/pkg/utils/utils_test.go new file mode 100644 index 000000000..7833d2091 --- /dev/null +++ b/stellar-auth/pkg/utils/utils_test.go @@ -0,0 +1,81 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_StringWithCharsetLenght(t *testing.T) { + charset := "asdfghjklzxcvbnm" + tokenLength := 4 + + token, err := StringWithCharset(tokenLength, charset) + require.NoError(t, err) + token2, err := StringWithCharset(tokenLength, charset) + require.NoError(t, err) + assert.Len(t, token, tokenLength) + assert.NotEqual(t, token, token2) +} + +func Test_ValidateEmail(t *testing.T) { + testCases := []struct { + email string + wantErr error + }{ + {"", fmt.Errorf("email cannot be empty")}, + {"notvalidemail", fmt.Errorf(`the provided email "notvalidemail" is not valid`)}, + {"valid@test.com", nil}, + {"valid+email@test.com", nil}, + } + + for _, tc := range testCases { + t.Run(tc.email, func(t *testing.T) { + gotError := ValidateEmail(tc.email) + assert.Equalf(t, tc.wantErr, gotError, "ValidateEmail(%q) should be %v, but got %v", tc.email, tc.wantErr, gotError) + }) + } +} + +func Test_TruncateString(t *testing.T) { + testCases := []struct { + name string + rawString string + borderSizeToKeep int + wantTruncated string + }{ + { + name: "string is shorter than borderSizeToKeep", + rawString: "abc", + borderSizeToKeep: 4, + wantTruncated: "abc", + }, + { + name: "string is longer than borderSizeToKeep", + rawString: "abcdefg", + borderSizeToKeep: 3, + wantTruncated: "abc...efg", + }, + { + name: "string is same length as borderSizeToKeep", + rawString: "abcdef", + borderSizeToKeep: 3, + wantTruncated: "abcdef", + }, + { + name: "string is empty", + rawString: "", + borderSizeToKeep: 3, + wantTruncated: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotTruncated := TruncateString(tc.rawString, tc.borderSizeToKeep) + assert.Equal(t, tc.wantTruncated, gotTruncated, "Expected Truncate(%q, %d) to be %q, but got %q", tc.rawString, tc.borderSizeToKeep, tc.wantTruncated, gotTruncated) + }) + } +} diff --git a/v1_compatibility/database_migration_compatibility.sh b/v1_compatibility/database_migration_compatibility.sh new file mode 100755 index 000000000..a81f9bdb8 --- /dev/null +++ b/v1_compatibility/database_migration_compatibility.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# This script is used to locally run the integration tests for compatibility between SDP-v1 and SDP-v2 +set -eu + +export DIVIDER="----------------------------------------" + +# prepare +rm -rf stellar-relief-backoffice-backend +docker ps -aq | xargs docker stop | xargs docker rm + +# Clone SDP v1 +echo $DIVIDER +echo "====> πŸ‘€Step 1: start cloning SDP v1 (stellar/stellar-relief-backoffice-backend)" +git clone -b main git@github.com:stellar/stellar-relief-backoffice-backend.git +echo "====> βœ…Step 1: finish cloning SDP v1 (stellar/stellar-relief-backoffice-backend)" + +# Run docker compose +echo $DIVIDER +echo "====> πŸ‘€Step 2: start calling docker compose up" +docker compose down && docker-compose up --abort-on-container-exit +echo "====> βœ…Step 2: finish calling docker-compose up" + +echo $DIVIDER +echo "πŸŽ‰πŸŽ‰πŸŽ‰πŸŽ‰ SUCCESS! πŸŽ‰πŸŽ‰πŸŽ‰πŸŽ‰" \ No newline at end of file diff --git a/v1_compatibility/docker-compose.yml b/v1_compatibility/docker-compose.yml new file mode 100644 index 000000000..72fee9af5 --- /dev/null +++ b/v1_compatibility/docker-compose.yml @@ -0,0 +1,52 @@ +version: '3' +services: + db: + image: postgres:14-alpine + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + PGPORT: 5432 + ports: + - "5432:5432" + + sdp-v1: + image: stellar/sdp-v1:latest + build: + context: ./stellar-relief-backoffice-backend/ + dockerfile: Dockerfile + environment: + DATABASE_URL: postgres://postgres:postgres@db:5432/postgres?sslmode=disable + DJANGO_SECRET_KEY: xxx + CIRCLE_API_KEY: foo + CIRCLE_WALLET_ID: foo + FILE_SERVER_HOST: foo + FILE_SERVER_UNREAD_PATH: foo + FILE_SERVER_READ_PATH: foo + FILE_SERVER_RECEIPTS_PATH: foo + command: + - sh + - -c + - | + pipenv run python manage.py migrate --settings payments.settings + sleep 30 + depends_on: + - db + + sdp-v2: + image: stellar/sdp-v2:latest + build: + context: ../ + dockerfile: Dockerfile + environment: + DATABASE_URL: postgres://postgres:postgres@db:5432/postgres?sslmode=disable + entrypoint: "" + command: + - sh + - -c + - | + sleep 5 + ./stellar-disbursement-platform db migrate up + depends_on: + - db + - sdp-v1