Skip to content

Commit

Permalink
[PECO-2050] Add custom auth headers into cloud fetch request (#249)
Browse files Browse the repository at this point in the history
When file encryption is enabled with customer provided keys (SSE-CPK),
we must pass the keys in HTTP headers in the fetch request. These
headers are provided in the property `httpHeaders` in the
`TSparkArrowResultLink`
  • Loading branch information
jackyhu-db authored Oct 25, 2024
2 parents 1e9d6ac + 977c5c1 commit db03838
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
6 changes: 6 additions & 0 deletions internal/rows/arrowbased/batchloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ func fetchBatchBytes(
return nil, err
}

if link.HttpHeaders != nil {
for key, value := range link.HttpHeaders {
req.Header.Set(key, value)
}
}

client := http.DefaultClient
res, err := client.Do(req)
if err != nil {
Expand Down
21 changes: 18 additions & 3 deletions internal/rows/arrowbased/batchloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ import (
"bytes"
"context"
"fmt"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/config"
"net/http"
"net/http/httptest"
"testing"
"time"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/pkg/errors"

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
Expand All @@ -28,8 +30,19 @@ func TestCloudFetchIterator(t *testing.T) {
defer server.Close()

t.Run("should fetch all the links", func(t *testing.T) {
cloudFetchHeaders := map[string]string{
"foo": "bar",
}

handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
for name, value := range cloudFetchHeaders {
if values, ok := r.Header[name]; ok {
if values[0] != value {
panic(errors.New("Missing auth headers"))
}
}
}
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
if err != nil {
panic(err)
Expand All @@ -44,12 +57,14 @@ func TestCloudFetchIterator(t *testing.T) {
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset,
RowCount: 1,
HttpHeaders: cloudFetchHeaders,
},
{
FileLink: server.URL,
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset + 1,
RowCount: 1,
HttpHeaders: cloudFetchHeaders,
},
}

Expand Down

0 comments on commit db03838

Please sign in to comment.