From aeb03a5437db81e1fa2230df94a6e0c3db96e1a8 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Tue, 9 Jul 2024 14:24:21 +0300 Subject: [PATCH] Refactor CloudFetch: process files sequentially via task queue Signed-off-by: Levko Kravets --- internal/fetcher/fetcher.go | 164 ------------ internal/fetcher/fetcher_test.go | 123 --------- internal/rows/arrowbased/batchloader.go | 335 +++++++++--------------- internal/rows/arrowbased/queue.go | 51 ++++ 4 files changed, 182 insertions(+), 491 deletions(-) delete mode 100644 internal/fetcher/fetcher.go delete mode 100644 internal/fetcher/fetcher_test.go create mode 100644 internal/rows/arrowbased/queue.go diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go deleted file mode 100644 index 8430ff0..0000000 --- a/internal/fetcher/fetcher.go +++ /dev/null @@ -1,164 +0,0 @@ -package fetcher - -import ( - "context" - "sync" - - "github.com/databricks/databricks-sql-go/driverctx" - dbsqllog "github.com/databricks/databricks-sql-go/logger" -) - -type FetchableItems[OutputType any] interface { - Fetch(ctx context.Context) (OutputType, error) -} - -type Fetcher[OutputType any] interface { - Err() error - Start() (<-chan OutputType, context.CancelFunc, error) -} - -type concurrentFetcher[I FetchableItems[O], O any] struct { - cancelChan chan bool - inputChan <-chan FetchableItems[O] - outChan chan O - err error - nWorkers int - mu sync.Mutex - start sync.Once - ctx context.Context - cancelFunc context.CancelFunc - *dbsqllog.DBSQLLogger -} - -func (rf *concurrentFetcher[I, O]) Err() error { - rf.mu.Lock() - defer rf.mu.Unlock() - return rf.err -} - -func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error) { - f.start.Do(func() { - // wait group for the worker routines - var wg sync.WaitGroup - - for i := 0; i < f.nWorkers; i++ { - - // increment wait group - wg.Add(1) - - f.logger().Trace().Msgf("concurrent fetcher starting worker %d", i) - go func(x int) { - // when work function remove one from the wait group - defer wg.Done() - // do the actual work - work(f, x) - f.logger().Trace().Msgf("concurrent fetcher worker %d done", x) - }(i) - - } - - // We want to close the output channel when all - // the workers are finished. This way the client won't - // be stuck waiting on the output channel. - go func() { - wg.Wait() - f.logger().Trace().Msg("concurrent fetcher closing output channel") - close(f.outChan) - }() - - // We return a cancel function so that the client can - // cancel fetching. 
- var cancelOnce sync.Once = sync.Once{} - f.cancelFunc = func() { - f.logger().Trace().Msg("concurrent fetcher cancel func") - cancelOnce.Do(func() { - f.logger().Trace().Msg("concurrent fetcher closing cancel channel") - close(f.cancelChan) - }) - } - }) - - return f.outChan, f.cancelFunc, nil -} - -func (f *concurrentFetcher[I, O]) setErr(err error) { - f.mu.Lock() - if f.err == nil { - f.err = err - } - f.mu.Unlock() -} - -func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger { - if f.DBSQLLogger == nil { - - f.DBSQLLogger = dbsqllog.WithContext(driverctx.ConnIdFromContext(f.ctx), driverctx.CorrelationIdFromContext(f.ctx), "") - - } - return f.DBSQLLogger -} - -func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) { - if nWorkers < 1 { - nWorkers = 1 - } - if maxItemsInMemory < 1 { - maxItemsInMemory = 1 - } - - // channel for loaded items - // TODO: pass buffer size - outputChannel := make(chan O, maxItemsInMemory) - - // channel to signal a cancel - stopChannel := make(chan bool) - - if ctx == nil { - ctx = context.Background() - } - - fetcher := &concurrentFetcher[I, O]{ - inputChan: inputChan, - outChan: outputChannel, - cancelChan: stopChannel, - ctx: ctx, - nWorkers: nWorkers, - } - - return fetcher, nil -} - -func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex int) { - - for { - select { - case <-f.cancelChan: - f.logger().Debug().Msgf("concurrent fetcher worker %d received cancel signal", workerIndex) - return - - case <-f.ctx.Done(): - f.logger().Debug().Msgf("concurrent fetcher worker %d context done", workerIndex) - return - - case input, ok := <-f.inputChan: - if ok { - f.logger().Trace().Msgf("concurrent fetcher worker %d loading item", workerIndex) - result, err := input.Fetch(f.ctx) - if err != nil { - f.logger().Trace().Msgf("concurrent fetcher worker %d received error", workerIndex) - f.setErr(err) - f.cancelFunc() - return - } else { - f.logger().Trace().Msgf("concurrent fetcher worker %d item loaded", workerIndex) - f.outChan <- result - } - } else { - f.logger().Trace().Msgf("concurrent fetcher ending %d", workerIndex) - return - } - - } - } - -} diff --git a/internal/fetcher/fetcher_test.go b/internal/fetcher/fetcher_test.go deleted file mode 100644 index dbe6ced..0000000 --- a/internal/fetcher/fetcher_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package fetcher - -import ( - "context" - "math" - "testing" - "time" - - "github.com/pkg/errors" -) - -// Create a mock struct for FetchableItems -type mockFetchableItem struct { - item int - wait time.Duration -} - -type mockOutput struct { - item int -} - -// Implement the Fetch method -func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) { - time.Sleep(m.wait) - outputs := make([]*mockOutput, 5) - for i := range outputs { - sampleOutput := mockOutput{item: m.item} - outputs[i] = &sampleOutput - } - return outputs, nil -} - -var _ FetchableItems[[]*mockOutput] = (*mockFetchableItem)(nil) - -func TestConcurrentFetcher(t *testing.T) { - t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) { - ctx := context.Background() - - inputChan := make(chan FetchableItems[[]*mockOutput], 10) - for i := 0; i < 10; i++ { - item := mockFetchableItem{item: i, wait: 1 * time.Second} - inputChan <- &item - } - close(inputChan) - - // Create a fetcher - fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan) - if err != nil { - 
t.Fatalf("Error creating fetcher: %v", err) - } - - start := time.Now() - outChan, _, err := fetcher.Start() - if err != nil { - t.Fatalf("Error starting fetcher: %v", err) - } - - var results []*mockOutput - for result := range outChan { - results = append(results, result...) - } - - // Check if the fetcher returned the expected results - expectedLen := 50 - if len(results) != expectedLen { - t.Errorf("Expected %d results, got %d", expectedLen, len(results)) - } - - // Check if the fetcher returned an error - if fetcher.Err() != nil { - t.Errorf("Fetcher returned an error: %v", fetcher.Err()) - } - - // Check if the fetcher took around the estimated amount of time - timeElapsed := time.Since(start) - rounds := int(math.Ceil(float64(10) / 3)) - expectedTime := time.Duration(rounds) * time.Second - buffer := 100 * time.Millisecond - if timeElapsed-expectedTime > buffer { - t.Errorf("Expected fetcher to take around %d ms, took %d ms", int64(expectedTime/time.Millisecond), int64(timeElapsed/time.Millisecond)) - } - }) - - t.Run("Cancel the concurrent fetcher", func(t *testing.T) { - // Create a context with a timeout - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - // Create an input channel - inputChan := make(chan FetchableItems[[]*mockOutput], 3) - for i := 0; i < 3; i++ { - item := mockFetchableItem{item: i, wait: 1 * time.Second} - inputChan <- &item - } - close(inputChan) - - // Create a new fetcher - fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan) - if err != nil { - t.Fatalf("Error creating fetcher: %v", err) - } - - // Start the fetcher - outChan, cancelFunc, err := fetcher.Start() - if err != nil { - t.Fatal(err) - } - - // Ensure that the fetcher is cancelled successfully - go func() { - cancelFunc() - }() - - for range outChan { - // Just drain the channel - } - - // Check if an error occurred - if err := fetcher.Err(); err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("unexpected error: %v", err) - } - }) -} diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 36dc202..8229924 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -17,7 +17,6 @@ import ( dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" - "github.com/databricks/databricks-sql-go/internal/fetcher" "github.com/databricks/databricks-sql-go/logger" ) @@ -27,26 +26,22 @@ type BatchIterator interface { Close() } -type BatchLoader interface { - rowscanner.Delimiter - GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) - Close() -} - func NewCloudBatchIterator( ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config, ) (BatchIterator, dbsqlerr.DBError) { - bl, err := newCloudBatchLoader(ctx, files, startRowOffset, cfg) - if err != nil { - return nil, err + bi := &cloudBatchIterator{ + ctx: ctx, + cfg: cfg, + startRowOffset: startRowOffset, + pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), + downloadTasks: NewQueue[cloudFetchDownloadTask](), } - bi := &batchIterator{ - nextBatchStart: bl.Start(), - batchLoader: bl, + for _, link := range files { + bi.pendingLinks.Enqueue(link) } return bi, nil @@ -60,22 +55,22 @@ func NewLocalBatchIterator( cfg *config.Config, ) (BatchIterator, dbsqlerr.DBError) { bi := 
&localBatchIterator{ - useLz4Compression: cfg.UseLz4Compression, - startRowOffset: startRowOffset, - arrowSchemaBytes: arrowSchemaBytes, - batches: batches, - index: -1, + cfg: cfg, + startRowOffset: startRowOffset, + arrowSchemaBytes: arrowSchemaBytes, + batches: batches, + index: -1, } return bi, nil } type localBatchIterator struct { - useLz4Compression bool - startRowOffset int64 - arrowSchemaBytes []byte - batches []*cli_service.TSparkArrowBatch - index int + cfg *config.Config + startRowOffset int64 + arrowSchemaBytes []byte + batches []*cli_service.TSparkArrowBatch + index int } var _ BatchIterator = (*localBatchIterator)(nil) @@ -88,7 +83,7 @@ func (bi *localBatchIterator) Next() (SparkArrowBatch, error) { reader := io.MultiReader( bytes.NewReader(bi.arrowSchemaBytes), - getReader(bytes.NewReader(ab.Batch), bi.useLz4Compression), + getReader(bytes.NewReader(ab.Batch), bi.cfg.UseLz4Compression), ) records, err := getArrowRecords(reader, bi.startRowOffset) @@ -118,165 +113,157 @@ func (bi *localBatchIterator) Close() { bi.index = len(bi.batches) } -func newCloudBatchLoader( - ctx context.Context, - files []*cli_service.TSparkArrowResultLink, - startRowOffset int64, - cfg *config.Config, -) (*batchLoader[*cloudURL], dbsqlerr.DBError) { +type cloudBatchIterator struct { + ctx context.Context + cfg *config.Config + startRowOffset int64 + pendingLinks Queue[cli_service.TSparkArrowResultLink] + downloadTasks Queue[cloudFetchDownloadTask] +} - if cfg == nil { - cfg = config.WithDefaults() - } +var _ BatchIterator = (*cloudBatchIterator)(nil) - inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(files)) - - var rowCount int64 - for i := range files { - f := files[i] - li := &cloudURL{ - // TSparkArrowResultLink: f, - Delimiter: rowscanner.NewDelimiter(f.StartRowOffset, f.RowCount), - fileLink: f.FileLink, - expiryTime: f.ExpiryTime, - minTimeToExpiry: cfg.MinTimeToExpiry, - compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression}, +func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { + for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) { + link := bi.pendingLinks.Dequeue() + logger.Debug().Msgf( + "CloudFetch: schedule link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + task := &cloudFetchDownloadTask{ + ctx: bi.ctx, + useLz4Compression: bi.cfg.UseLz4Compression, + link: link, + resultChan: make(chan SparkArrowBatch), + errorChan: make(chan error), + minTimeToExpiry: bi.cfg.MinTimeToExpiry, } - inputChan <- li - - rowCount += f.RowCount + task.Run() + bi.downloadTasks.Enqueue(task) } - // make sure to close input channel or fetcher will block waiting for more inputs - close(inputChan) - - f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) - cbl := &batchLoader[*cloudURL]{ - Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), - fetcher: f, - ctx: ctx, + task := bi.downloadTasks.Dequeue() + if task == nil { + return nil, io.EOF } - return cbl, nil + return task.GetResult() } -type batchLoader[T interface { - Fetch(ctx context.Context) (SparkArrowBatch, error) -}] struct { - rowscanner.Delimiter - fetcher fetcher.Fetcher[SparkArrowBatch] - arrowBatches []SparkArrowBatch - ctx context.Context +func (bi *cloudBatchIterator) HasNext() bool { + return (bi.pendingLinks.Len() > 0) || (bi.downloadTasks.Len() > 0) } -var _ BatchLoader = (*batchLoader[*localBatch])(nil) - -func (cbl *batchLoader[T]) GetBatchFor(rowNumber 
int64) (SparkArrowBatch, dbsqlerr.DBError) { +func (bi *cloudBatchIterator) Close() { + bi.pendingLinks.Clear() // Clear the list + // TODO: Cancel all download tasks + bi.downloadTasks.Clear() // Clear the list +} - logger.Debug().Msgf("batchLoader.GetBatchFor(%d)", rowNumber) +type cloudFetchDownloadTask struct { + ctx context.Context + useLz4Compression bool + minTimeToExpiry time.Duration + link *cli_service.TSparkArrowResultLink + resultChan chan SparkArrowBatch + errorChan chan error +} - for i := range cbl.arrowBatches { - logger.Debug().Msgf(" trying batch for range [%d..%d]", cbl.arrowBatches[i].Start(), cbl.arrowBatches[i].End()) - if cbl.arrowBatches[i].Contains(rowNumber) { - logger.Debug().Msgf(" found batch containing the requested row %d", rowNumber) - return cbl.arrowBatches[i], nil +func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { + link := cft.link + + select { + case batch, ok := <-cft.resultChan: + if ok { + logger.Debug().Msgf( + "CloudFetch: received data for link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return batch, nil + } + case err, ok := <-cft.errorChan: + if ok { + logger.Debug().Msgf( + "CloudFetch: failed to download link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return nil, err } } - logger.Debug().Msgf(" batch not found, trying to download more") - - batchChan, _, err := cbl.fetcher.Start() - var emptyBatch SparkArrowBatch - if err != nil { - logger.Debug().Msgf(" no batch found for row %d", rowNumber) - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) - } + logger.Debug().Msgf( + "CloudFetch: this should never happen; link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + return nil, nil // TODO: ??? +} - for { - batch, ok := <-batchChan - if !ok { - err := cbl.fetcher.Err() - if err != nil { - logger.Debug().Msgf(" no batch found for row %d", rowNumber) - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) - } - break - } +func (cft *cloudFetchDownloadTask) Run() { + go func() { + link := cft.link - cbl.arrowBatches = append(cbl.arrowBatches, batch) - logger.Debug().Msgf(" trying newly downloaded batch for range [%d..%d]", batch.Start(), batch.End()) - if batch.Contains(rowNumber) { - logger.Debug().Msgf(" found batch containing the requested row %d", rowNumber) - return batch, nil + logger.Debug().Msgf( + "CloudFetch: start downloading link at offset %d row count %d", + link.StartRowOffset, + link.RowCount, + ) + data, err := cft.fetchBatchBytes() + if err != nil { + cft.errorChan <- err + return } - } - logger.Debug().Msgf(" no batch found for row %d", rowNumber) + // TODO: error handling? 
+ defer data.Close() - return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) -} + reader := getReader(data, cft.useLz4Compression) -func (cbl *batchLoader[T]) Close() { - for i := range cbl.arrowBatches { - cbl.arrowBatches[i].Close() - } -} - -type compressibleBatch struct { - useLz4Compression bool -} - -func getReader(r io.Reader, useLz4Compression bool) io.Reader { - if useLz4Compression { - return lz4.NewReader(r) - } - return r -} + records, err := getArrowRecords(reader, cft.link.StartRowOffset) + if err != nil { + cft.errorChan <- err + return + } -type cloudURL struct { - compressibleBatch - rowscanner.Delimiter - fileLink string - expiryTime int64 - minTimeToExpiry time.Duration + batch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(cft.link.StartRowOffset, cft.link.RowCount), + arrowRecords: records, + } + cft.resultChan <- &batch + }() } -func (cu *cloudURL) Fetch(ctx context.Context) (SparkArrowBatch, error) { - var sab SparkArrowBatch - - if isLinkExpired(cu.expiryTime, cu.minTimeToExpiry) { - return sab, errors.New(dbsqlerr.ErrLinkExpired) +func (cft *cloudFetchDownloadTask) fetchBatchBytes() (io.ReadCloser, error) { + if isLinkExpired(cft.link.ExpiryTime, cft.minTimeToExpiry) { + return nil, errors.New(dbsqlerr.ErrLinkExpired) } - req, err := http.NewRequestWithContext(ctx, "GET", cu.fileLink, nil) + // TODO: Retry on HTTP errors + req, err := http.NewRequestWithContext(cft.ctx, "GET", cft.link.FileLink, nil) if err != nil { - return sab, err + return nil, err } client := http.DefaultClient res, err := client.Do(req) if err != nil { - return sab, err + return nil, err } if res.StatusCode != http.StatusOK { - return sab, dbsqlerrint.NewDriverError(ctx, errArrowRowsCloudFetchDownloadFailure, err) + return nil, dbsqlerrint.NewDriverError(cft.ctx, errArrowRowsCloudFetchDownloadFailure, err) } - defer res.Body.Close() - - r := getReader(res.Body, cu.compressibleBatch.useLz4Compression) - - records, err := getArrowRecords(r, cu.Start()) - if err != nil { - return nil, err - } + return res.Body, nil +} - arrowBatch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(cu.Start(), cu.Count()), - arrowRecords: records, +func getReader(r io.Reader, useLz4Compression bool) io.Reader { + if useLz4Compression { + return lz4.NewReader(r) } - - return &arrowBatch, nil + return r } func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { @@ -284,35 +271,6 @@ func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { return expiryTime-bufferSecs < time.Now().Unix() } -var _ fetcher.FetchableItems[SparkArrowBatch] = (*cloudURL)(nil) - -type localBatch struct { - compressibleBatch - rowscanner.Delimiter - batchBytes []byte - arrowSchemaBytes []byte -} - -var _ fetcher.FetchableItems[SparkArrowBatch] = (*localBatch)(nil) - -func (lb *localBatch) Fetch(ctx context.Context) (SparkArrowBatch, error) { - r := getReader(bytes.NewReader(lb.batchBytes), lb.compressibleBatch.useLz4Compression) - r = io.MultiReader(bytes.NewReader(lb.arrowSchemaBytes), r) - - records, err := getArrowRecords(r, lb.Start()) - if err != nil { - return &sparkArrowBatch{}, err - } - - lb.batchBytes = nil - batch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(lb.Start(), lb.Count()), - arrowRecords: records, - } - - return &batch, nil -} - func getArrowRecords(r io.Reader, startRowOffset int64) ([]SparkArrowRecord, error) { ipcReader, err := ipc.NewReader(r) if err != nil { @@ -346,34 +304,3 @@ func getArrowRecords(r 
io.Reader, startRowOffset int64) ([]SparkArrowRecord, err return records, nil } - -type batchIterator struct { - nextBatchStart int64 - batchLoader BatchLoader -} - -var _ BatchIterator = (*batchIterator)(nil) - -func (bi *batchIterator) Next() (SparkArrowBatch, error) { - if !bi.HasNext() { - return nil, io.EOF - } - if bi != nil && bi.batchLoader != nil { - batch, err := bi.batchLoader.GetBatchFor(bi.nextBatchStart) - if batch != nil && err == nil { - bi.nextBatchStart = batch.Start() + batch.Count() - } - return batch, err - } - return nil, nil -} - -func (bi *batchIterator) HasNext() bool { - return bi != nil && bi.batchLoader != nil && bi.batchLoader.Contains(bi.nextBatchStart) -} - -func (bi *batchIterator) Close() { - if bi != nil && bi.batchLoader != nil { - bi.batchLoader.Close() - } -} diff --git a/internal/rows/arrowbased/queue.go b/internal/rows/arrowbased/queue.go new file mode 100644 index 0000000..ed1d16f --- /dev/null +++ b/internal/rows/arrowbased/queue.go @@ -0,0 +1,51 @@ +package arrowbased + +import ( + "container/list" +) + +type Queue[ItemType any] interface { + Enqueue(item *ItemType) + Dequeue() *ItemType + Clear() + Len() int +} + +func NewQueue[ItemType any]() Queue[ItemType] { + return &queue[ItemType]{ + items: list.New(), + } +} + +type queue[ItemType any] struct { + items *list.List +} + +var _ Queue[any] = (*queue[any])(nil) + +func (q *queue[ItemType]) Enqueue(item *ItemType) { + q.items.PushBack(item) +} + +func (q *queue[ItemType]) Dequeue() *ItemType { + el := q.items.Front() + if el == nil { + return nil + } + q.items.Remove(el) + + value, ok := el.Value.(*ItemType) + if !ok { + return nil + } + + return value +} + +func (q *queue[ItemType]) Clear() { + q.items.Init() +} + +func (q *queue[ItemType]) Len() int { + return q.items.Len() +}
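
Note: deleting fetcher_test.go leaves the new queue without test coverage. A minimal test sketch for the generic FIFO queue added in internal/rows/arrowbased/queue.go could look like the following; it is illustrative only (the test function and file are not part of this patch) and assumes it lives in the same arrowbased package:

package arrowbased

import "testing"

// Illustrative sketch: exercises Enqueue/Dequeue/Clear/Len on the
// generic FIFO queue introduced in queue.go.
func TestQueueFIFOOrder(t *testing.T) {
	q := NewQueue[int]()

	// Dequeue on an empty queue yields nil rather than blocking.
	if q.Dequeue() != nil {
		t.Fatal("expected nil from empty queue")
	}

	one, two := 1, 2
	q.Enqueue(&one)
	q.Enqueue(&two)
	if q.Len() != 2 {
		t.Fatalf("expected Len 2, got %d", q.Len())
	}

	// Items come back in insertion (FIFO) order.
	if v := q.Dequeue(); v == nil || *v != 1 {
		t.Fatalf("expected 1, got %v", v)
	}

	// Clear drops any remaining items.
	q.Clear()
	if q.Len() != 0 {
		t.Fatalf("expected empty queue after Clear, got %d", q.Len())
	}
}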
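
For context on the new control flow: cloudBatchIterator.Next keeps at most cfg.MaxDownloadThreads downloads in flight and consumes results strictly in link order, while each cloudFetchDownloadTask.Run goroutine sends exactly one value on either its result channel or its error channel. Below is a self-contained sketch of that handshake; every name in it is invented for illustration and none of it comes from the driver:

package main

import "fmt"

// task mirrors the cloudFetchDownloadTask handshake: run starts a
// goroutine that sends exactly one value on resultChan or errorChan.
type task struct {
	id         int
	resultChan chan string
	errorChan  chan error
}

func (t *task) run() {
	go func() {
		// Stand-in for the HTTP download and Arrow decoding.
		t.resultChan <- fmt.Sprintf("batch %d", t.id)
	}()
}

func (t *task) result() (string, error) {
	select {
	case r := <-t.resultChan:
		return r, nil
	case err := <-t.errorChan:
		return "", err
	}
}

func main() {
	const maxInFlight = 3
	pending := []int{0, 1, 2, 3, 4, 5}
	var inFlight []*task

	for len(pending) > 0 || len(inFlight) > 0 {
		// Top up: start tasks until the concurrency cap is reached,
		// as Next() does with pendingLinks and downloadTasks.
		for len(inFlight) < maxInFlight && len(pending) > 0 {
			t := &task{
				id:         pending[0],
				resultChan: make(chan string),
				errorChan:  make(chan error),
			}
			pending = pending[1:]
			t.run()
			inFlight = append(inFlight, t)
		}

		// Consume strictly in scheduling order, so batches stay in
		// row order even though downloads overlap.
		head := inFlight[0]
		inFlight = inFlight[1:]
		r, err := head.result()
		if err != nil {
			fmt.Println("error:", err)
			return
		}
		fmt.Println(r)
	}
}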