From 92d1ef422f03d055456fac1bad7a589da8605f61 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 4 Jul 2024 19:47:15 +0300 Subject: [PATCH 1/7] Refactoring: Hide Arrow and CloudFetch batch loaders behind BatchIterator interface (simplifies usage and encapsulates implementation details) Signed-off-by: Levko Kravets --- .../rows/arrowbased/arrowRecordIterator.go | 23 ++------------- internal/rows/arrowbased/arrowRows.go | 12 ++------ internal/rows/arrowbased/batchloader.go | 29 +++++++++++++++---- 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/internal/rows/arrowbased/arrowRecordIterator.go b/internal/rows/arrowbased/arrowRecordIterator.go index 898d0a45..583cbd04 100644 --- a/internal/rows/arrowbased/arrowRecordIterator.go +++ b/internal/rows/arrowbased/arrowRecordIterator.go @@ -163,29 +163,10 @@ func (ri *arrowRecordIterator) getBatchIterator() error { // Create a new batch iterator from a page of the result set func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) { - bl, err := ri.newBatchLoader(fr) - if err != nil { - return nil, err - } - - bi, err := NewBatchIterator(bl) - - return bi, err -} - -// Create a new batch loader from a page of the result set -func (ri *arrowRecordIterator) newBatchLoader(fr *cli_service.TFetchResultsResp) (BatchLoader, error) { rowSet := fr.Results - var bl BatchLoader - var err error if len(rowSet.ResultLinks) > 0 { - bl, err = NewCloudBatchLoader(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) + return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) } else { - bl, err = NewLocalBatchLoader(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) + return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) } - if err != nil { - return nil, err - } - - return bl, nil } diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index 89fe9b94..f6a60c58 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -112,27 +112,21 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsToTimestampFn, err) } - var bl BatchLoader + var bi BatchIterator var err2 dbsqlerr.DBError if len(rowSet.ResultLinks) > 0 { logger.Debug().Msgf("Initialize CloudFetch loader, row set start offset: %d, file list:", rowSet.StartRowOffset) for _, resultLink := range rowSet.ResultLinks { logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount) } - bl, err2 = NewCloudBatchLoader(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) - logger.Debug().Msgf("Created CloudFetch concurrent loader, rows range [%d..%d]", bl.Start(), bl.End()) + bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) } else { - bl, err2 = NewLocalBatchLoader(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) + bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) } if err2 != nil { return nil, err2 } - bi, err := NewBatchIterator(bl) - if err != nil { - return nil, err2 - } - var location *time.Location = time.UTC if cfg != nil { if cfg.Location != nil { diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index e36153ab..85c33b50 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -33,16 +33,35 @@ type BatchLoader interface { Close() } -func NewBatchIterator(batchLoader BatchLoader) (BatchIterator, dbsqlerr.DBError) { +func NewCloudBatchIterator(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) { + bl, err := newCloudBatchLoader(ctx, files, startRowOffset, cfg) + if err != nil { + return nil, err + } + + bi := &batchIterator{ + nextBatchStart: bl.Start(), + batchLoader: bl, + } + + return bi, nil +} + +func NewLocalBatchIterator(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) { + bl, err := newLocalBatchLoader(ctx, batches, startRowOffset, arrowSchemaBytes, cfg) + if err != nil { + return nil, err + } + bi := &batchIterator{ - nextBatchStart: batchLoader.Start(), - batchLoader: batchLoader, + nextBatchStart: bl.Start(), + batchLoader: bl, } return bi, nil } -func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { +func newCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { if cfg == nil { cfg = config.WithDefaults() @@ -79,7 +98,7 @@ func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowRe return cbl, nil } -func NewLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { +func newLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { if cfg == nil { cfg = config.WithDefaults() From 72540a950ee3c40b1e0613ece9542e80b9dfb320 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Tue, 9 Jul 2024 13:38:10 +0300 Subject: [PATCH 2/7] Split out regular Arrow results handling from CloudFetch Signed-off-by: Levko Kravets --- internal/rows/arrowbased/batchloader.go | 125 ++++++++++++++---------- 1 file changed, 76 insertions(+), 49 deletions(-) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 85c33b50..36dc2022 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -33,7 +33,12 @@ type BatchLoader interface { Close() } -func NewCloudBatchIterator(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) { +func NewCloudBatchIterator( + ctx context.Context, + files []*cli_service.TSparkArrowResultLink, + startRowOffset int64, + cfg *config.Config, +) (BatchIterator, dbsqlerr.DBError) { bl, err := newCloudBatchLoader(ctx, files, startRowOffset, cfg) if err != nil { return nil, err @@ -47,21 +52,78 @@ func NewCloudBatchIterator(ctx context.Context, files []*cli_service.TSparkArrow return bi, nil } -func NewLocalBatchIterator(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) { - bl, err := newLocalBatchLoader(ctx, batches, startRowOffset, arrowSchemaBytes, cfg) - if err != nil { - return nil, err +func NewLocalBatchIterator( + ctx context.Context, + batches []*cli_service.TSparkArrowBatch, + startRowOffset int64, + arrowSchemaBytes []byte, + cfg *config.Config, +) (BatchIterator, dbsqlerr.DBError) { + bi := &localBatchIterator{ + useLz4Compression: cfg.UseLz4Compression, + startRowOffset: startRowOffset, + arrowSchemaBytes: arrowSchemaBytes, + batches: batches, + index: -1, } - bi := &batchIterator{ - nextBatchStart: bl.Start(), - batchLoader: bl, + return bi, nil +} + +type localBatchIterator struct { + useLz4Compression bool + startRowOffset int64 + arrowSchemaBytes []byte + batches []*cli_service.TSparkArrowBatch + index int +} + +var _ BatchIterator = (*localBatchIterator)(nil) + +func (bi *localBatchIterator) Next() (SparkArrowBatch, error) { + cnt := len(bi.batches) + bi.index++ + if bi.index < cnt { + ab := bi.batches[bi.index] + + reader := io.MultiReader( + bytes.NewReader(bi.arrowSchemaBytes), + getReader(bytes.NewReader(ab.Batch), bi.useLz4Compression), + ) + + records, err := getArrowRecords(reader, bi.startRowOffset) + if err != nil { + return &sparkArrowBatch{}, err + } + + batch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(bi.startRowOffset, ab.RowCount), + arrowRecords: records, + } + + bi.startRowOffset += ab.RowCount // advance to beginning of the next batch + + return &batch, nil } - return bi, nil + bi.index = cnt + return nil, io.EOF } -func newCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { +func (bi *localBatchIterator) HasNext() bool { + return bi.index < len(bi.batches) +} + +func (bi *localBatchIterator) Close() { + bi.index = len(bi.batches) +} + +func newCloudBatchLoader( + ctx context.Context, + files []*cli_service.TSparkArrowResultLink, + startRowOffset int64, + cfg *config.Config, +) (*batchLoader[*cloudURL], dbsqlerr.DBError) { if cfg == nil { cfg = config.WithDefaults() @@ -98,41 +160,6 @@ func newCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowRe return cbl, nil } -func newLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { - - if cfg == nil { - cfg = config.WithDefaults() - } - - var startRow int64 = startRowOffset - var rowCount int64 - inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(batches)) - for i := range batches { - b := batches[i] - if b != nil { - li := &localBatch{ - Delimiter: rowscanner.NewDelimiter(startRow, b.RowCount), - batchBytes: b.Batch, - arrowSchemaBytes: arrowSchemaBytes, - compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression}, - } - inputChan <- li - startRow = startRow + b.RowCount - rowCount += b.RowCount - } - } - close(inputChan) - - f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) - cbl := &batchLoader[*localBatch]{ - Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), - fetcher: f, - ctx: ctx, - } - - return cbl, nil -} - type batchLoader[T interface { Fetch(ctx context.Context) (SparkArrowBatch, error) }] struct { @@ -199,8 +226,8 @@ type compressibleBatch struct { useLz4Compression bool } -func (cb compressibleBatch) getReader(r io.Reader) io.Reader { - if cb.useLz4Compression { +func getReader(r io.Reader, useLz4Compression bool) io.Reader { + if useLz4Compression { return lz4.NewReader(r) } return r @@ -237,7 +264,7 @@ func (cu *cloudURL) Fetch(ctx context.Context) (SparkArrowBatch, error) { defer res.Body.Close() - r := cu.compressibleBatch.getReader(res.Body) + r := getReader(res.Body, cu.compressibleBatch.useLz4Compression) records, err := getArrowRecords(r, cu.Start()) if err != nil { @@ -269,7 +296,7 @@ type localBatch struct { var _ fetcher.FetchableItems[SparkArrowBatch] = (*localBatch)(nil) func (lb *localBatch) Fetch(ctx context.Context) (SparkArrowBatch, error) { - r := lb.compressibleBatch.getReader(bytes.NewReader(lb.batchBytes)) + r := getReader(bytes.NewReader(lb.batchBytes), lb.compressibleBatch.useLz4Compression) r = io.MultiReader(bytes.NewReader(lb.arrowSchemaBytes), r) records, err := getArrowRecords(r, lb.Start()) From aeb03a5437db81e1fa2230df94a6e0c3db96e1a8 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Tue, 9 Jul 2024 14:24:21 +0300 Subject: [PATCH 3/7] Refactor CloudFetch: process files sequentially via task queue Signed-off-by: Levko Kravets --- internal/fetcher/fetcher.go | 164 ------------ internal/fetcher/fetcher_test.go | 123 --------- internal/rows/arrowbased/batchloader.go | 335 +++++++++--------------- internal/rows/arrowbased/queue.go | 51 ++++ 4 files changed, 182 insertions(+), 491 deletions(-) delete mode 100644 internal/fetcher/fetcher.go delete mode 100644 internal/fetcher/fetcher_test.go create mode 100644 internal/rows/arrowbased/queue.go diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go deleted file mode 100644 index 8430ff0d..00000000 --- a/internal/fetcher/fetcher.go +++ /dev/null @@ -1,164 +0,0 @@ -package fetcher - -import ( - "context" - "sync" - - "github.com/databricks/databricks-sql-go/driverctx" - dbsqllog "github.com/databricks/databricks-sql-go/logger" -) - -type FetchableItems[OutputType any] interface { - Fetch(ctx context.Context) (OutputType, error) -} - -type Fetcher[OutputType any] interface { - Err() error - Start() (<-chan OutputType, context.CancelFunc, error) -} - -type concurrentFetcher[I FetchableItems[O], O any] struct { - cancelChan chan bool - inputChan <-chan FetchableItems[O] - outChan chan O - err error - nWorkers int - mu sync.Mutex - start sync.Once - ctx context.Context - cancelFunc context.CancelFunc - *dbsqllog.DBSQLLogger -} - -func (rf *concurrentFetcher[I, O]) Err() error { - rf.mu.Lock() - defer rf.mu.Unlock() - return rf.err -} - -func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error) { - f.start.Do(func() { - // wait group for the worker routines - var wg sync.WaitGroup - - for i := 0; i < f.nWorkers; i++ { - - // increment wait group - wg.Add(1) - - f.logger().Trace().Msgf("concurrent fetcher starting worker %d", i) - go func(x int) { - // when work function remove one from the wait group - defer wg.Done() - // do the actual work - work(f, x) - f.logger().Trace().Msgf("concurrent fetcher worker %d done", x) - }(i) - - } - - // We want to close the output channel when all - // the workers are finished. This way the client won't - // be stuck waiting on the output channel. - go func() { - wg.Wait() - f.logger().Trace().Msg("concurrent fetcher closing output channel") - close(f.outChan) - }() - - // We return a cancel function so that the client can - // cancel fetching. - var cancelOnce sync.Once = sync.Once{} - f.cancelFunc = func() { - f.logger().Trace().Msg("concurrent fetcher cancel func") - cancelOnce.Do(func() { - f.logger().Trace().Msg("concurrent fetcher closing cancel channel") - close(f.cancelChan) - }) - } - }) - - return f.outChan, f.cancelFunc, nil -} - -func (f *concurrentFetcher[I, O]) setErr(err error) { - f.mu.Lock() - if f.err == nil { - f.err = err - } - f.mu.Unlock() -} - -func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger { - if f.DBSQLLogger == nil { - - f.DBSQLLogger = dbsqllog.WithContext(driverctx.ConnIdFromContext(f.ctx), driverctx.CorrelationIdFromContext(f.ctx), "") - - } - return f.DBSQLLogger -} - -func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) { - if nWorkers < 1 { - nWorkers = 1 - } - if maxItemsInMemory < 1 { - maxItemsInMemory = 1 - } - - // channel for loaded items - // TODO: pass buffer size - outputChannel := make(chan O, maxItemsInMemory) - - // channel to signal a cancel - stopChannel := make(chan bool) - - if ctx == nil { - ctx = context.Background() - } - - fetcher := &concurrentFetcher[I, O]{ - inputChan: inputChan, - outChan: outputChannel, - cancelChan: stopChannel, - ctx: ctx, - nWorkers: nWorkers, - } - - return fetcher, nil -} - -func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex int) { - - for { - select { - case <-f.cancelChan: - f.logger().Debug().Msgf("concurrent fetcher worker %d received cancel signal", workerIndex) - return - - case <-f.ctx.Done(): - f.logger().Debug().Msgf("concurrent fetcher worker %d context done", workerIndex) - return - - case input, ok := <-f.inputChan: - if ok { - f.logger().Trace().Msgf("concurrent fetcher worker %d loading item", workerIndex) - result, err := input.Fetch(f.ctx) - if err != nil { - f.logger().Trace().Msgf("concurrent fetcher worker %d received error", workerIndex) - f.setErr(err) - f.cancelFunc() - return - } else { - f.logger().Trace().Msgf("concurrent fetcher worker %d item loaded", workerIndex) - f.outChan <- result - } - } else { - f.logger().Trace().Msgf("concurrent fetcher ending %d", workerIndex) - return - } - - } - } - -} diff --git a/internal/fetcher/fetcher_test.go b/internal/fetcher/fetcher_test.go deleted file mode 100644 index dbe6ced0..00000000 --- a/internal/fetcher/fetcher_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package fetcher - -import ( - "context" - "math" - "testing" - "time" - - "github.com/pkg/errors" -) - -// Create a mock struct for FetchableItems -type mockFetchableItem struct { - item int - wait time.Duration -} - -type mockOutput struct { - item int -} - -// Implement the Fetch method -func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) { - time.Sleep(m.wait) - outputs := make([]*mockOutput, 5) - for i := range outputs { - sampleOutput := mockOutput{item: m.item} - outputs[i] = &sampleOutput - } - return outputs, nil -} - -var _ FetchableItems[[]*mockOutput] = (*mockFetchableItem)(nil) - -func TestConcurrentFetcher(t *testing.T) { - t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) { - ctx := context.Background() - - inputChan := make(chan FetchableItems[[]*mockOutput], 10) - for i := 0; i < 10; i++ { - item := mockFetchableItem{item: i, wait: 1 * time.Second} - inputChan <- &item - } - close(inputChan) - - // Create a fetcher - fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan) - if err != nil { - t.Fatalf("Error creating fetcher: %v", err) - } - - start := time.Now() - outChan, _, err := fetcher.Start() - if err != nil { - t.Fatalf("Error starting fetcher: %v", err) - } - - var results []*mockOutput - for result := range outChan { - results = append(results, result...) - } - - // Check if the fetcher returned the expected results - expectedLen := 50 - if len(results) != expectedLen { - t.Errorf("Expected %d results, got %d", expectedLen, len(results)) - } - - // Check if the fetcher returned an error - if fetcher.Err() != nil { - t.Errorf("Fetcher returned an error: %v", fetcher.Err()) - } - - // Check if the fetcher took around the estimated amount of time - timeElapsed := time.Since(start) - rounds := int(math.Ceil(float64(10) / 3)) - expectedTime := time.Duration(rounds) * time.Second - buffer := 100 * time.Millisecond - if timeElapsed-expectedTime > buffer { - t.Errorf("Expected fetcher to take around %d ms, took %d ms", int64(expectedTime/time.Millisecond), int64(timeElapsed/time.Millisecond)) - } - }) - - t.Run("Cancel the concurrent fetcher", func(t *testing.T) { - // Create a context with a timeout - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - // Create an input channel - inputChan := make(chan FetchableItems[[]*mockOutput], 3) - for i := 0; i < 3; i++ { - item := mockFetchableItem{item: i, wait: 1 * time.Second} - inputChan <- &item - } - close(inputChan) - - // Create a new fetcher - fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan) - if err != nil { - t.Fatalf("Error creating fetcher: %v", err) - } - - // Start the fetcher - outChan, cancelFunc, err := fetcher.Start() - if err != nil { - t.Fatal(err) - } - - // Ensure that the fetcher is cancelled successfully - go func() { - cancelFunc() - }() - - for range outChan { - // Just drain the channel - } - - // Check if an error occurred - if err := fetcher.Err(); err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("unexpected error: %v", err) - } - }) -} diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 36dc2022..82299243 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -17,7 +17,6 @@ import ( dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" - "github.com/databricks/databricks-sql-go/internal/fetcher" "github.com/databricks/databricks-sql-go/logger" ) @@ -27,26 +26,22 @@ type BatchIterator interface { Close() } -type BatchLoader interface { - rowscanner.Delimiter - GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) - Close() -} - func NewCloudBatchIterator( ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config, ) (BatchIterator, dbsqlerr.DBError) { - bl, err := newCloudBatchLoader(ctx, files, startRowOffset, cfg) - if err != nil { - return nil, err + bi := &cloudBatchIterator{ + ctx: ctx, + cfg: cfg, + startRowOffset: startRowOffset, + pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), + downloadTasks: NewQueue[cloudFetchDownloadTask](), } - bi := &batchIterator{ - nextBatchStart: bl.Start(), - batchLoader: bl, + for _, link := range files { + bi.pendingLinks.Enqueue(link) } return bi, nil @@ -60,22 +55,22 @@ func NewLocalBatchIterator( cfg *config.Config, ) (BatchIterator, dbsqlerr.DBError) { bi := &localBatchIterator{ - useLz4Compression: cfg.UseLz4Compression, - startRowOffset: startRowOffset, - arrowSchemaBytes: arrowSchemaBytes, - batches: batches, - index: -1, + cfg: cfg, + startRowOffset: startRowOffset, + arrowSchemaBytes: arrowSchemaBytes, + batches: batches, + index: -1, } return bi, nil } type localBatchIterator struct { - useLz4Compression bool - startRowOffset int64 - arrowSchemaBytes []byte - batches []*cli_service.TSparkArrowBatch - index int + cfg *config.Config + startRowOffset int64 + arrowSchemaBytes []byte + batches []*cli_service.TSparkArrowBatch + index int } var _ BatchIterator = (*localBatchIterator)(nil) @@ -88,7 +83,7 @@ func (bi *localBatchIterator) Next() (SparkArrowBatch, error) { reader := io.MultiReader( bytes.NewReader(bi.arrowSchemaBytes), - getReader(bytes.NewReader(ab.Batch), bi.useLz4Compression), + getReader(bytes.NewReader(ab.Batch), bi.cfg.UseLz4Compression), ) records, err := getArrowRecords(reader, bi.startRowOffset) @@ -118,165 +113,157 @@ func (bi *localBatchIterator) Close() { bi.index = len(bi.batches) } -func newCloudBatchLoader( - ctx context.Context, - files []*cli_service.TSparkArrowResultLink, - startRowOffset int64, - cfg *config.Config, -) (*batchLoader[*cloudURL], dbsqlerr.DBError) { +type cloudBatchIterator struct { + ctx context.Context + cfg *config.Config + startRowOffset int64 + pendingLinks Queue[cli_service.TSparkArrowResultLink] + downloadTasks Queue[cloudFetchDownloadTask] +} - if cfg == nil { - cfg = config.WithDefaults() - } +var _ BatchIterator = (*cloudBatchIterator)(nil) - inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(files)) - - var rowCount int64 - for i := range files { - f := files[i] - li := &cloudURL{ - // TSparkArrowResultLink: f, - Delimiter: rowscanner.NewDelimiter(f.StartRowOffset, f.RowCount), - fileLink: f.FileLink, - expiryTime: f.ExpiryTime, - minTimeToExpiry: cfg.MinTimeToExpiry, - compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression}, +func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { + for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) { + link := bi.pendingLinks.Dequeue() + logger.Debug().Msgf( + "CloudFetch: schedule link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + task := &cloudFetchDownloadTask{ + ctx: bi.ctx, + useLz4Compression: bi.cfg.UseLz4Compression, + link: link, + resultChan: make(chan SparkArrowBatch), + errorChan: make(chan error), + minTimeToExpiry: bi.cfg.MinTimeToExpiry, } - inputChan <- li - - rowCount += f.RowCount + task.Run() + bi.downloadTasks.Enqueue(task) } - // make sure to close input channel or fetcher will block waiting for more inputs - close(inputChan) - - f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) - cbl := &batchLoader[*cloudURL]{ - Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), - fetcher: f, - ctx: ctx, + task := bi.downloadTasks.Dequeue() + if task == nil { + return nil, io.EOF } - return cbl, nil + return task.GetResult() } -type batchLoader[T interface { - Fetch(ctx context.Context) (SparkArrowBatch, error) -}] struct { - rowscanner.Delimiter - fetcher fetcher.Fetcher[SparkArrowBatch] - arrowBatches []SparkArrowBatch - ctx context.Context +func (bi *cloudBatchIterator) HasNext() bool { + return (bi.pendingLinks.Len() > 0) || (bi.downloadTasks.Len() > 0) } -var _ BatchLoader = (*batchLoader[*localBatch])(nil) - -func (cbl *batchLoader[T]) GetBatchFor(rowNumber int64) (SparkArrowBatch, dbsqlerr.DBError) { +func (bi *cloudBatchIterator) Close() { + bi.pendingLinks.Clear() // Clear the list + // TODO: Cancel all download tasks + bi.downloadTasks.Clear() // Clear the list +} - logger.Debug().Msgf("batchLoader.GetBatchFor(%d)", rowNumber) +type cloudFetchDownloadTask struct { + ctx context.Context + useLz4Compression bool + minTimeToExpiry time.Duration + link *cli_service.TSparkArrowResultLink + resultChan chan SparkArrowBatch + errorChan chan error +} - for i := range cbl.arrowBatches { - logger.Debug().Msgf(" trying batch for range [%d..%d]", cbl.arrowBatches[i].Start(), cbl.arrowBatches[i].End()) - if cbl.arrowBatches[i].Contains(rowNumber) { - logger.Debug().Msgf(" found batch containing the requested row %d", rowNumber) - return cbl.arrowBatches[i], nil +func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { + link := cft.link + + select { + case batch, ok := <-cft.resultChan: + if ok { + logger.Debug().Msgf( + "CloudFetch: received data for link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return batch, nil + } + case err, ok := <-cft.errorChan: + if ok { + logger.Debug().Msgf( + "CloudFetch: failed to download link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return nil, err } } - logger.Debug().Msgf(" batch not found, trying to download more") - - batchChan, _, err := cbl.fetcher.Start() - var emptyBatch SparkArrowBatch - if err != nil { - logger.Debug().Msgf(" no batch found for row %d", rowNumber) - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) - } + logger.Debug().Msgf( + "CloudFetch: this should never happen; link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return nil, nil // TODO: ??? +} - for { - batch, ok := <-batchChan - if !ok { - err := cbl.fetcher.Err() - if err != nil { - logger.Debug().Msgf(" no batch found for row %d", rowNumber) - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) - } - break - } +func (cft *cloudFetchDownloadTask) Run() { + go func() { + link := cft.link - cbl.arrowBatches = append(cbl.arrowBatches, batch) - logger.Debug().Msgf(" trying newly downloaded batch for range [%d..%d]", batch.Start(), batch.End()) - if batch.Contains(rowNumber) { - logger.Debug().Msgf(" found batch containing the requested row %d", rowNumber) - return batch, nil + logger.Debug().Msgf( + "CloudFetch: start downloading link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + data, err := cft.fetchBatchBytes() + if err != nil { + cft.errorChan <- err + return } - } - logger.Debug().Msgf(" no batch found for row %d", rowNumber) + // TODO: error handling? + defer data.Close() - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) -} + reader := getReader(data, cft.useLz4Compression) -func (cbl *batchLoader[T]) Close() { - for i := range cbl.arrowBatches { - cbl.arrowBatches[i].Close() - } -} - -type compressibleBatch struct { - useLz4Compression bool -} - -func getReader(r io.Reader, useLz4Compression bool) io.Reader { - if useLz4Compression { - return lz4.NewReader(r) - } - return r -} + records, err := getArrowRecords(reader, cft.link.StartRowOffset) + if err != nil { + cft.errorChan <- err + return + } -type cloudURL struct { - compressibleBatch - rowscanner.Delimiter - fileLink string - expiryTime int64 - minTimeToExpiry time.Duration + batch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(cft.link.StartRowOffset, cft.link.RowCount), + arrowRecords: records, + } + cft.resultChan <- &batch + }() } -func (cu *cloudURL) Fetch(ctx context.Context) (SparkArrowBatch, error) { - var sab SparkArrowBatch - - if isLinkExpired(cu.expiryTime, cu.minTimeToExpiry) { - return sab, errors.New(dbsqlerr.ErrLinkExpired) +func (cft *cloudFetchDownloadTask) fetchBatchBytes() (io.ReadCloser, error) { + if isLinkExpired(cft.link.ExpiryTime, cft.minTimeToExpiry) { + return nil, errors.New(dbsqlerr.ErrLinkExpired) } - req, err := http.NewRequestWithContext(ctx, "GET", cu.fileLink, nil) + // TODO: Retry on HTTP errors + req, err := http.NewRequestWithContext(cft.ctx, "GET", cft.link.FileLink, nil) if err != nil { - return sab, err + return nil, err } client := http.DefaultClient res, err := client.Do(req) if err != nil { - return sab, err + return nil, err } if res.StatusCode != http.StatusOK { - return sab, dbsqlerrint.NewDriverError(ctx, errArrowRowsCloudFetchDownloadFailure, err) + return nil, dbsqlerrint.NewDriverError(cft.ctx, errArrowRowsCloudFetchDownloadFailure, err) } - defer res.Body.Close() - - r := getReader(res.Body, cu.compressibleBatch.useLz4Compression) - - records, err := getArrowRecords(r, cu.Start()) - if err != nil { - return nil, err - } + return res.Body, nil +} - arrowBatch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(cu.Start(), cu.Count()), - arrowRecords: records, +func getReader(r io.Reader, useLz4Compression bool) io.Reader { + if useLz4Compression { + return lz4.NewReader(r) } - - return &arrowBatch, nil + return r } func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { @@ -284,35 +271,6 @@ func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { return expiryTime-bufferSecs < time.Now().Unix() } -var _ fetcher.FetchableItems[SparkArrowBatch] = (*cloudURL)(nil) - -type localBatch struct { - compressibleBatch - rowscanner.Delimiter - batchBytes []byte - arrowSchemaBytes []byte -} - -var _ fetcher.FetchableItems[SparkArrowBatch] = (*localBatch)(nil) - -func (lb *localBatch) Fetch(ctx context.Context) (SparkArrowBatch, error) { - r := getReader(bytes.NewReader(lb.batchBytes), lb.compressibleBatch.useLz4Compression) - r = io.MultiReader(bytes.NewReader(lb.arrowSchemaBytes), r) - - records, err := getArrowRecords(r, lb.Start()) - if err != nil { - return &sparkArrowBatch{}, err - } - - lb.batchBytes = nil - batch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(lb.Start(), lb.Count()), - arrowRecords: records, - } - - return &batch, nil -} - func getArrowRecords(r io.Reader, startRowOffset int64) ([]SparkArrowRecord, error) { ipcReader, err := ipc.NewReader(r) if err != nil { @@ -346,34 +304,3 @@ func getArrowRecords(r io.Reader, startRowOffset int64) ([]SparkArrowRecord, err return records, nil } - -type batchIterator struct { - nextBatchStart int64 - batchLoader BatchLoader -} - -var _ BatchIterator = (*batchIterator)(nil) - -func (bi *batchIterator) Next() (SparkArrowBatch, error) { - if !bi.HasNext() { - return nil, io.EOF - } - if bi != nil && bi.batchLoader != nil { - batch, err := bi.batchLoader.GetBatchFor(bi.nextBatchStart) - if batch != nil && err == nil { - bi.nextBatchStart = batch.Start() + batch.Count() - } - return batch, err - } - return nil, nil -} - -func (bi *batchIterator) HasNext() bool { - return bi != nil && bi.batchLoader != nil && bi.batchLoader.Contains(bi.nextBatchStart) -} - -func (bi *batchIterator) Close() { - if bi != nil && bi.batchLoader != nil { - bi.batchLoader.Close() - } -} diff --git a/internal/rows/arrowbased/queue.go b/internal/rows/arrowbased/queue.go new file mode 100644 index 00000000..ed1d16f5 --- /dev/null +++ b/internal/rows/arrowbased/queue.go @@ -0,0 +1,51 @@ +package arrowbased + +import ( + "container/list" +) + +type Queue[ItemType any] interface { + Enqueue(item *ItemType) + Dequeue() *ItemType + Clear() + Len() int +} + +func NewQueue[ItemType any]() Queue[ItemType] { + return &queue[ItemType]{ + items: list.New(), + } +} + +type queue[ItemType any] struct { + items *list.List +} + +var _ Queue[any] = (*queue[any])(nil) + +func (q *queue[ItemType]) Enqueue(item *ItemType) { + q.items.PushBack(item) +} + +func (q *queue[ItemType]) Dequeue() *ItemType { + el := q.items.Front() + if el == nil { + return nil + } + q.items.Remove(el) + + value, ok := el.Value.(*ItemType) + if !ok { + return nil + } + + return value +} + +func (q *queue[ItemType]) Clear() { + q.items.Init() +} + +func (q *queue[ItemType]) Len() int { + return q.items.Len() +} From b65c005142df7ad300c55eaf7611a55ff74aa031 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Wed, 10 Jul 2024 11:50:11 +0300 Subject: [PATCH 4/7] Refine code Signed-off-by: Levko Kravets --- internal/rows/arrowbased/batchloader.go | 69 ++++++++++++++----------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 82299243..2f52efa9 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -135,8 +135,7 @@ func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { ctx: bi.ctx, useLz4Compression: bi.cfg.UseLz4Compression, link: link, - resultChan: make(chan SparkArrowBatch), - errorChan: make(chan error), + resultChan: make(chan cloudFetchDownloadTaskResult), minTimeToExpiry: bi.cfg.MinTimeToExpiry, } task.Run() @@ -161,70 +160,74 @@ func (bi *cloudBatchIterator) Close() { bi.downloadTasks.Clear() // Clear the list } +type cloudFetchDownloadTaskResult struct { + batch SparkArrowBatch + err error +} + type cloudFetchDownloadTask struct { ctx context.Context useLz4Compression bool minTimeToExpiry time.Duration link *cli_service.TSparkArrowResultLink - resultChan chan SparkArrowBatch - errorChan chan error + resultChan chan cloudFetchDownloadTaskResult } func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { link := cft.link - select { - case batch, ok := <-cft.resultChan: - if ok { - logger.Debug().Msgf( - "CloudFetch: received data for link at offset %d row count %d", - link.StartRowOffset, - link.RowCount, - ) - return batch, nil - } - case err, ok := <-cft.errorChan: - if ok { + result, ok := <-cft.resultChan + if ok { + if result.err != nil { logger.Debug().Msgf( "CloudFetch: failed to download link at offset %d row count %d", link.StartRowOffset, link.RowCount, ) - return nil, err + return nil, result.err } + logger.Debug().Msgf( + "CloudFetch: received data for link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return result.batch, nil } logger.Debug().Msgf( - "CloudFetch: this should never happen; link at offset %d row count %d", + "CloudFetch: channel was closed before result was received; link at offset %d row count %d", link.StartRowOffset, link.RowCount, ) - return nil, nil // TODO: ??? + return nil, nil // TODO: return error? } func (cft *cloudFetchDownloadTask) Run() { go func() { - link := cft.link - logger.Debug().Msgf( "CloudFetch: start downloading link at offset %d row count %d", - link.StartRowOffset, - link.RowCount, + cft.link.StartRowOffset, + cft.link.RowCount, ) - data, err := cft.fetchBatchBytes() + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry) if err != nil { - cft.errorChan <- err + cft.resultChan <- cloudFetchDownloadTaskResult{batch: nil, err: err} return } // TODO: error handling? defer data.Close() + logger.Debug().Msgf( + "CloudFetch: reading records for link at offset %d row count %d", + cft.link.StartRowOffset, + cft.link.RowCount, + ) reader := getReader(data, cft.useLz4Compression) records, err := getArrowRecords(reader, cft.link.StartRowOffset) if err != nil { - cft.errorChan <- err + cft.resultChan <- cloudFetchDownloadTaskResult{batch: nil, err: err} return } @@ -232,17 +235,21 @@ func (cft *cloudFetchDownloadTask) Run() { Delimiter: rowscanner.NewDelimiter(cft.link.StartRowOffset, cft.link.RowCount), arrowRecords: records, } - cft.resultChan <- &batch + cft.resultChan <- cloudFetchDownloadTaskResult{batch: &batch, err: nil} }() } -func (cft *cloudFetchDownloadTask) fetchBatchBytes() (io.ReadCloser, error) { - if isLinkExpired(cft.link.ExpiryTime, cft.minTimeToExpiry) { +func fetchBatchBytes( + ctx context.Context, + link *cli_service.TSparkArrowResultLink, + minTimeToExpiry time.Duration, +) (io.ReadCloser, error) { + if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) } // TODO: Retry on HTTP errors - req, err := http.NewRequestWithContext(cft.ctx, "GET", cft.link.FileLink, nil) + req, err := http.NewRequestWithContext(ctx, "GET", link.FileLink, nil) if err != nil { return nil, err } @@ -253,7 +260,7 @@ func (cft *cloudFetchDownloadTask) fetchBatchBytes() (io.ReadCloser, error) { return nil, err } if res.StatusCode != http.StatusOK { - return nil, dbsqlerrint.NewDriverError(cft.ctx, errArrowRowsCloudFetchDownloadFailure, err) + return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsCloudFetchDownloadFailure, err) } return res.Body, nil From c629bae72083135865fa65f1be28e1b70df1e3a8 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Wed, 10 Jul 2024 11:51:13 +0300 Subject: [PATCH 5/7] Handle errors and cancel the remaining tasks Signed-off-by: Levko Kravets --- internal/rows/arrowbased/batchloader.go | 29 ++++++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 2f52efa9..41fb4597 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -131,8 +131,11 @@ func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { link.StartRowOffset, link.RowCount, ) + + cancelCtx, cancelFn := context.WithCancel(bi.ctx) task := &cloudFetchDownloadTask{ - ctx: bi.ctx, + ctx: cancelCtx, + cancel: cancelFn, useLz4Compression: bi.cfg.UseLz4Compression, link: link, resultChan: make(chan cloudFetchDownloadTaskResult), @@ -147,7 +150,17 @@ func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { return nil, io.EOF } - return task.GetResult() + batch, err := task.GetResult() + + // once we've got an errored out task - cancel the remaining ones + if err != nil { + bi.Close() + return nil, err + } + + // explicitly call cancel function on successfully completed task to avoid context leak + task.cancel() + return batch, nil } func (bi *cloudBatchIterator) HasNext() bool { @@ -155,9 +168,11 @@ func (bi *cloudBatchIterator) HasNext() bool { } func (bi *cloudBatchIterator) Close() { - bi.pendingLinks.Clear() // Clear the list - // TODO: Cancel all download tasks - bi.downloadTasks.Clear() // Clear the list + bi.pendingLinks.Clear() + for bi.downloadTasks.Len() > 0 { + task := bi.downloadTasks.Dequeue() + task.cancel() + } } type cloudFetchDownloadTaskResult struct { @@ -167,6 +182,7 @@ type cloudFetchDownloadTaskResult struct { type cloudFetchDownloadTask struct { ctx context.Context + cancel context.CancelFunc useLz4Compression bool minTimeToExpiry time.Duration link *cli_service.TSparkArrowResultLink @@ -180,9 +196,10 @@ func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { if ok { if result.err != nil { logger.Debug().Msgf( - "CloudFetch: failed to download link at offset %d row count %d", + "CloudFetch: failed to download link at offset %d row count %d, reason: %s", link.StartRowOffset, link.RowCount, + result.err.Error(), ) return nil, result.err } From 11fff499d987e39e6e9e9b8ba0fa876033fa1171 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Wed, 10 Jul 2024 12:06:24 +0300 Subject: [PATCH 6/7] Cleanup & refine code Signed-off-by: Levko Kravets --- internal/rows/arrowbased/batchloader.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 41fb4597..551dd07d 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -211,16 +211,19 @@ func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { return result.batch, nil } + // This branch should never be reached. If you see this message - something got really wrong logger.Debug().Msgf( "CloudFetch: channel was closed before result was received; link at offset %d row count %d", link.StartRowOffset, link.RowCount, ) - return nil, nil // TODO: return error? + return nil, nil } func (cft *cloudFetchDownloadTask) Run() { go func() { + defer close(cft.resultChan) + logger.Debug().Msgf( "CloudFetch: start downloading link at offset %d row count %d", cft.link.StartRowOffset, @@ -232,7 +235,7 @@ func (cft *cloudFetchDownloadTask) Run() { return } - // TODO: error handling? + // io.ReadCloser.Close() may return an error, but in this case it should be safe to ignore (I hope so) defer data.Close() logger.Debug().Msgf( From ec7811a25e4bd64a1d35faed3e0a0888b00b6fda Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 25 Jul 2024 16:51:21 +0300 Subject: [PATCH 7/7] Update tests Signed-off-by: Levko Kravets --- .../arrowbased/arrowRecordIterator_test.go | 36 ++- internal/rows/arrowbased/arrowRows_test.go | 177 +++++------ internal/rows/arrowbased/batchloader.go | 8 +- internal/rows/arrowbased/batchloader_test.go | 287 ++++++++++++------ 4 files changed, 314 insertions(+), 194 deletions(-) diff --git a/internal/rows/arrowbased/arrowRecordIterator_test.go b/internal/rows/arrowbased/arrowRecordIterator_test.go index a3e4040c..a3b67687 100644 --- a/internal/rows/arrowbased/arrowRecordIterator_test.go +++ b/internal/rows/arrowbased/arrowRecordIterator_test.go @@ -19,6 +19,8 @@ import ( func TestArrowRecordIterator(t *testing.T) { t.Run("with direct results", func(t *testing.T) { + logger := dbsqllog.WithContext("connectionId", "correlationId", "") + executeStatementResp := cli_service.TExecuteStatementResp{} loadTestData2(t, "directResultsMultipleFetch/ExecuteStatement.json", &executeStatementResp) @@ -30,32 +32,37 @@ func TestArrowRecordIterator(t *testing.T) { var fetchesInfo []fetchResultsInfo - client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) - logger := dbsqllog.WithContext("connectionId", "correlationId", "") + simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2}) rpi := rowscanner.NewResultPageIterator( rowscanner.NewDelimiter(0, 7311), 5000, nil, false, - client, + simpleClient, "connectionId", "correlationId", - logger) + logger, + ) - bl, err := NewLocalBatchLoader( + cfg := *config.WithDefaults() + + bi, err := NewLocalBatchIterator( context.Background(), executeStatementResp.DirectResults.ResultSet.Results.ArrowBatches, 0, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, - nil, + &cfg, ) - assert.Nil(t, err) - bi, err := NewBatchIterator(bl) assert.Nil(t, err) - cfg := *config.WithDefaults() - rs := NewArrowRecordIterator(context.Background(), rpi, bi, executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, cfg) + rs := NewArrowRecordIterator( + context.Background(), + rpi, + bi, + executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema, + cfg, + ) defer rs.Close() hasNext := rs.HasNext() @@ -108,6 +115,7 @@ func TestArrowRecordIterator(t *testing.T) { }) t.Run("no direct results", func(t *testing.T) { + logger := dbsqllog.WithContext("connectionId", "correlationId", "") fetchResp1 := cli_service.TFetchResultsResp{} loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1) @@ -120,17 +128,17 @@ func TestArrowRecordIterator(t *testing.T) { var fetchesInfo []fetchResultsInfo - client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) - logger := dbsqllog.WithContext("connectionId", "correlationId", "") + simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3}) rpi := rowscanner.NewResultPageIterator( rowscanner.NewDelimiter(0, 0), 5000, nil, false, - client, + simpleClient, "connectionId", "correlationId", - logger) + logger, + ) cfg := *config.WithDefaults() rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg) diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index c693ff56..c43674eb 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -469,7 +469,6 @@ func TestArrowRowScanner(t *testing.T) { }) t.Run("Create column value holders on first batch load", func(t *testing.T) { - rowSet := &cli_service.TRowSet{ ArrowBatches: []*cli_service.TSparkArrowBatch{ {RowCount: 5}, @@ -494,13 +493,13 @@ func TestArrowRowScanner(t *testing.T) { &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(2, 3), Record: &fakeRecord{}}}} b2 := &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}} b3 := &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}} - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{b1, b2, b3}, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi var callCount int ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { @@ -517,25 +516,25 @@ func TestArrowRowScanner(t *testing.T) { assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(1) assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(2) assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(5) assert.Nil(t, err) assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) - assert.Equal(t, 2, fbl.callCount) + assert.Equal(t, 2, fbi.callCount) }) @@ -557,25 +556,24 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi err := ars.loadBatchFor(0) assert.Nil(t, err) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) err = ars.loadBatchFor(0) assert.Nil(t, err) - assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, 1, fbi.callCount) }) t.Run("loadBatch index out of bounds", func(t *testing.T) { @@ -596,17 +594,16 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi err := ars.loadBatchFor(-1) assert.NotNil(t, err) @@ -636,17 +633,16 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi ars.valueContainerMaker = &fakeValueContainerMaker{ fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { @@ -657,7 +653,6 @@ func TestArrowRowScanner(t *testing.T) { err := ars.loadBatchFor(0) assert.NotNil(t, err) assert.ErrorContains(t, err, "error making containers") - }) t.Run("loadBatch record read failure", func(t *testing.T) { @@ -679,18 +674,17 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, - err: dbsqlerrint.NewDriverError(context.TODO(), "error reading record", nil), + index: -1, + callCount: 0, + err: dbsqlerrint.NewDriverError(context.TODO(), "error reading record", nil), } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi err := ars.loadBatchFor(0) assert.NotNil(t, err) @@ -716,40 +710,39 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi for _, i := range []int64{0, 1, 2, 3, 4} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, fbl.lastReadBatch) - assert.Equal(t, 1, fbl.callCount) - assert.Equal(t, int64(0), fbl.lastReadBatch.Start()) + assert.NotNil(t, fbi.lastReadBatch) + assert.Equal(t, 1, fbi.callCount) + assert.Equal(t, int64(0), fbi.lastReadBatch.Start()) } for _, i := range []int64{5, 6, 7} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, fbl.lastReadBatch) - assert.Equal(t, 2, fbl.callCount) - assert.Equal(t, int64(5), fbl.lastReadBatch.Start()) + assert.NotNil(t, fbi.lastReadBatch) + assert.Equal(t, 2, fbi.callCount) + assert.Equal(t, int64(5), fbi.lastReadBatch.Start()) } for _, i := range []int64{8, 9, 10, 11, 12, 13, 14} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, fbl.lastReadBatch) - assert.Equal(t, 3, fbl.callCount) - assert.Equal(t, int64(8), fbl.lastReadBatch.Start()) + assert.NotNil(t, fbi.lastReadBatch) + assert.Equal(t, 3, fbi.callCount) + assert.Equal(t, int64(8), fbi.lastReadBatch.Start()) } err := ars.loadBatchFor(-1) @@ -869,17 +862,16 @@ func TestArrowRowScanner(t *testing.T) { ars.UseArrowNativeDecimal = true ars.UseArrowNativeIntervalTypes = true - fbl := &fakeBatchLoader{ - Delimiter: rowscanner.NewDelimiter(0, 15), + fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, }, + index: -1, + callCount: 0, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { columnValueHolders := make([]columnValues, len(ars.arrowSchema.Fields())) @@ -1049,17 +1041,13 @@ func TestArrowRowScanner(t *testing.T) { ars := d.(*arrowRowScanner) assert.Equal(t, int64(53940), ars.NRows()) - bi, ok := ars.batchIterator.(*batchIterator) + bi, ok := ars.batchIterator.(*localBatchIterator) assert.True(t, ok) - bl := bi.batchLoader - fbl := &batchLoaderWrapper{ - Delimiter: rowscanner.NewDelimiter(bl.Start(), bl.Count()), - bl: bl, + fbi := &batchIteratorWrapper{ + bi: bi, } - var e dbsqlerr.DBError - ars.batchIterator, e = NewBatchIterator(fbl) - assert.Nil(t, e) + ars.batchIterator = fbi dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) for i := int64(0); i < ars.NRows(); i = i + 1 { @@ -1079,7 +1067,7 @@ func TestArrowRowScanner(t *testing.T) { } } - assert.Equal(t, 54, fbl.callCount) + assert.Equal(t, 54, fbi.callCount) }) t.Run("Retrieve values - native arrow schema", func(t *testing.T) { @@ -1647,49 +1635,68 @@ func (cv *fakeColumnValues) SetValueArray(colData arrow.ArrayData) error { return nil } -type fakeBatchLoader struct { - rowscanner.Delimiter +type fakeBatchIterator struct { batches []SparkArrowBatch + index int callCount int err dbsqlerr.DBError lastReadBatch SparkArrowBatch } -var _ BatchLoader = (*fakeBatchLoader)(nil) +var _ BatchIterator = (*fakeBatchIterator)(nil) + +func (fbi *fakeBatchIterator) Next() (SparkArrowBatch, error) { + fbi.callCount += 1 -func (fbl *fakeBatchLoader) Close() {} -func (fbl *fakeBatchLoader) GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) { - fbl.callCount += 1 - if fbl.err != nil { - return nil, fbl.err + if fbi.err != nil { + return nil, fbi.err } - for i := range fbl.batches { - if fbl.batches[i].Contains(recordNum) { - fbl.lastReadBatch = fbl.batches[i] - return fbl.batches[i], nil - } + cnt := len(fbi.batches) + fbi.index++ + if fbi.index < cnt { + fbi.lastReadBatch = fbi.batches[fbi.index] + return fbi.lastReadBatch, nil } - return nil, dbsqlerrint.NewDriverError(context.Background(), errArrowRowsInvalidRowNumber(recordNum), nil) + + fbi.lastReadBatch = nil + return nil, io.EOF +} + +func (fbi *fakeBatchIterator) HasNext() bool { + // `Next()` will first increment an index, and only then return a batch + // So `HasNext` should check if index can be incremented and still be within array + return fbi.index+1 < len(fbi.batches) +} + +func (fbi *fakeBatchIterator) Close() { + fbi.index = len(fbi.batches) + fbi.lastReadBatch = nil } -type batchLoaderWrapper struct { - rowscanner.Delimiter - bl BatchLoader +type batchIteratorWrapper struct { + bi BatchIterator callCount int lastLoadedBatch SparkArrowBatch } -var _ BatchLoader = (*batchLoaderWrapper)(nil) +var _ BatchIterator = (*batchIteratorWrapper)(nil) -func (fbl *batchLoaderWrapper) Close() { fbl.bl.Close() } -func (fbl *batchLoaderWrapper) GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) { - fbl.callCount += 1 - batch, err := fbl.bl.GetBatchFor(recordNum) - fbl.lastLoadedBatch = batch +func (biw *batchIteratorWrapper) Next() (SparkArrowBatch, error) { + biw.callCount += 1 + batch, err := biw.bi.Next() + biw.lastLoadedBatch = batch return batch, err } +func (biw *batchIteratorWrapper) HasNext() bool { + return biw.bi.HasNext() +} + +func (biw *batchIteratorWrapper) Close() { + biw.bi.Close() +} + type fakeRecord struct { fnRelease func() fnRetain func() diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 551dd07d..45b067dd 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -3,6 +3,7 @@ package arrowbased import ( "bytes" "context" + "fmt" "io" "time" @@ -106,7 +107,9 @@ func (bi *localBatchIterator) Next() (SparkArrowBatch, error) { } func (bi *localBatchIterator) HasNext() bool { - return bi.index < len(bi.batches) + // `Next()` will first increment an index, and only then return a batch + // So `HasNext` should check if index can be incremented and still be within array + return bi.index+1 < len(bi.batches) } func (bi *localBatchIterator) Close() { @@ -280,7 +283,8 @@ func fetchBatchBytes( return nil, err } if res.StatusCode != http.StatusOK { - return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsCloudFetchDownloadFailure, err) + msg := fmt.Sprintf("%s: %s %d", errArrowRowsCloudFetchDownloadFailure, "HTTP error", res.StatusCode) + return nil, dbsqlerrint.NewDriverError(ctx, msg, err) } return res.Body, nil diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index 35bad337..e47eef08 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -4,9 +4,11 @@ import ( "bytes" "context" "fmt" + dbsqlerr "github.com/databricks/databricks-sql-go/errors" + "github.com/databricks/databricks-sql-go/internal/cli_service" + "github.com/databricks/databricks-sql-go/internal/config" "net/http" "net/http/httptest" - "reflect" "testing" "time" @@ -14,117 +16,216 @@ import ( "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/ipc" "github.com/apache/arrow/go/v12/arrow/memory" - dbsqlerr "github.com/databricks/databricks-sql-go/errors" - dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" - "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) -func TestCloudURLFetch(t *testing.T) { +func TestCloudFetchIterator(t *testing.T) { var handler func(w http.ResponseWriter, r *http.Request) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler(w, r) })) defer server.Close() - testTable := []struct { - name string - response func(w http.ResponseWriter, r *http.Request) - linkExpired bool - expectedResponse SparkArrowBatch - expectedErr error - }{ - { - name: "cloud-fetch-happy-case", - response: func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) - if err != nil { - panic(err) - } + + t.Run("should fetch all the links", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, }, - linkExpired: false, - expectedResponse: &sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(0, 3), - arrowRecords: []SparkArrowRecord{ - &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 3), Record: generateArrowRecord()}, - &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(3, 3), Record: generateArrowRecord()}, - }, + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset + 1, + RowCount: 1, }, - expectedErr: nil, - }, - { - name: "cloud-fetch-expired_link", - response: func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) - if err != nil { - panic(err) - } + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + if err != nil { + panic(err) + } + + cbi := bi.(*cloudBatchIterator) + + assert.True(t, bi.HasNext()) + assert.Equal(t, cbi.pendingLinks.Len(), len(links)) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + + // get first link - should succeed + sab1, err2 := bi.Next() + if err2 != nil { + panic(err2) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab1.Start(), startRowOffset) + + // get second link - should succeed + sab2, err3 := bi.Next() + if err3 != nil { + panic(err3) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-2) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab2.Start(), startRowOffset+sab1.Count()) + + // all links downloaded, should be no more data + assert.False(t, bi.HasNext()) + }) + + t.Run("should fail on expired link", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, }, - linkExpired: true, - expectedResponse: nil, - expectedErr: errors.New(dbsqlerr.ErrLinkExpired), - }, - { - name: "cloud-fetch-http-error", - response: func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(-10 * time.Minute).Unix(), // expired link + StartRowOffset: startRowOffset + 1, + RowCount: 1, }, - linkExpired: false, - expectedResponse: nil, - expectedErr: dbsqlerrint.NewDriverError(context.TODO(), errArrowRowsCloudFetchDownloadFailure, nil), - }, - } + } - for _, tc := range testTable { - t.Run(tc.name, func(t *testing.T) { - handler = tc.response + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 - expiryTime := time.Now() - // If link expired, subtract 1 sec from current time to get expiration time - if tc.linkExpired { - expiryTime = expiryTime.Add(-1 * time.Second) - } else { - expiryTime = expiryTime.Add(10 * time.Second) - } + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + if err != nil { + panic(err) + } - cu := &cloudURL{ - Delimiter: rowscanner.NewDelimiter(0, 3), - fileLink: server.URL, - expiryTime: expiryTime.Unix(), - } + cbi := bi.(*cloudBatchIterator) - ctx := context.Background() - - resp, err := cu.Fetch(ctx) - - if tc.expectedResponse != nil { - assert.NotNil(t, resp) - esab, ok := tc.expectedResponse.(*sparkArrowBatch) - assert.True(t, ok) - asab, ok2 := resp.(*sparkArrowBatch) - assert.True(t, ok2) - if !reflect.DeepEqual(esab.Delimiter, asab.Delimiter) { - t.Errorf("expected (%v), got (%v)", esab.Delimiter, asab.Delimiter) - } - assert.Equal(t, len(esab.arrowRecords), len(asab.arrowRecords)) - for i := range esab.arrowRecords { - er := esab.arrowRecords[i] - ar := asab.arrowRecords[i] - - eb := generateMockArrowBytes(er) - ab := generateMockArrowBytes(ar) - assert.Equal(t, eb, ab) - } - } + assert.True(t, bi.HasNext()) + assert.Equal(t, cbi.pendingLinks.Len(), len(links)) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + + // get first link - should succeed + sab1, err2 := bi.Next() + if err2 != nil { + panic(err2) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab1.Start(), startRowOffset) + + // get second link - should fail + _, err3 := bi.Next() + assert.NotNil(t, err3) + assert.ErrorContains(t, err3, dbsqlerr.ErrLinkExpired) + }) + + t.Run("should fail on HTTP errors", func(t *testing.T) { + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset + 1, + RowCount: 1, + }, + } - if !errors.Is(err, tc.expectedErr) { - assert.EqualErrorf(t, err, fmt.Sprintf("%v", tc.expectedErr), "expected (%v), got (%v)", tc.expectedErr, err) + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + if err != nil { + panic(err) + } + + cbi := bi.(*cloudBatchIterator) + + assert.True(t, bi.HasNext()) + assert.Equal(t, cbi.pendingLinks.Len(), len(links)) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + + // set handler for the first link, which returns some data + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) } - }) - } + } + + // get first link - should succeed + sab1, err2 := bi.Next() + if err2 != nil { + panic(err2) + } + + assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) + assert.Equal(t, cbi.downloadTasks.Len(), 0) + assert.Equal(t, sab1.Start(), startRowOffset) + + // set handler for the first link, which fails with some non-retryable HTTP error + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + } + + // get second link - should fail + _, err3 := bi.Next() + assert.NotNil(t, err3) + assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) + }) } func generateArrowRecord() arrow.Record {