diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go
index 85c33b5..36dc202 100644
--- a/internal/rows/arrowbased/batchloader.go
+++ b/internal/rows/arrowbased/batchloader.go
@@ -33,7 +33,12 @@ type BatchLoader interface {
 	Close()
 }
 
-func NewCloudBatchIterator(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) {
+func NewCloudBatchIterator(
+	ctx context.Context,
+	files []*cli_service.TSparkArrowResultLink,
+	startRowOffset int64,
+	cfg *config.Config,
+) (BatchIterator, dbsqlerr.DBError) {
 	bl, err := newCloudBatchLoader(ctx, files, startRowOffset, cfg)
 	if err != nil {
 		return nil, err
@@ -47,21 +52,78 @@ func NewCloudBatchIterator(ctx context.Context, files []*cli_service.TSparkArrow
 	return bi, nil
 }
 
-func NewLocalBatchIterator(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) {
-	bl, err := newLocalBatchLoader(ctx, batches, startRowOffset, arrowSchemaBytes, cfg)
-	if err != nil {
-		return nil, err
+func NewLocalBatchIterator(
+	ctx context.Context,
+	batches []*cli_service.TSparkArrowBatch,
+	startRowOffset int64,
+	arrowSchemaBytes []byte,
+	cfg *config.Config,
+) (BatchIterator, dbsqlerr.DBError) {
+	bi := &localBatchIterator{
+		useLz4Compression: cfg.UseLz4Compression,
+		startRowOffset:    startRowOffset,
+		arrowSchemaBytes:  arrowSchemaBytes,
+		batches:           batches,
+		index:             -1,
 	}
 
-	bi := &batchIterator{
-		nextBatchStart: bl.Start(),
-		batchLoader:    bl,
+	return bi, nil
+}
+
+type localBatchIterator struct {
+	useLz4Compression bool
+	startRowOffset    int64
+	arrowSchemaBytes  []byte
+	batches           []*cli_service.TSparkArrowBatch
+	index             int
+}
+
+var _ BatchIterator = (*localBatchIterator)(nil)
+
+func (bi *localBatchIterator) Next() (SparkArrowBatch, error) {
+	cnt := len(bi.batches)
+	bi.index++
+	if bi.index < cnt {
+		ab := bi.batches[bi.index]
+
+		reader := io.MultiReader(
+			bytes.NewReader(bi.arrowSchemaBytes),
+			getReader(bytes.NewReader(ab.Batch), bi.useLz4Compression),
+		)
+
+		records, err := getArrowRecords(reader, bi.startRowOffset)
+		if err != nil {
+			return &sparkArrowBatch{}, err
+		}
+
+		batch := sparkArrowBatch{
+			Delimiter:    rowscanner.NewDelimiter(bi.startRowOffset, ab.RowCount),
+			arrowRecords: records,
+		}
+
+		bi.startRowOffset += ab.RowCount // advance to beginning of the next batch
+
+		return &batch, nil
 	}
 
-	return bi, nil
+	bi.index = cnt
+	return nil, io.EOF
 }
 
-func newCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) {
+func (bi *localBatchIterator) HasNext() bool {
+	return bi.index < len(bi.batches)
+}
+
+func (bi *localBatchIterator) Close() {
+	bi.index = len(bi.batches)
+}
+
+func newCloudBatchLoader(
+	ctx context.Context,
+	files []*cli_service.TSparkArrowResultLink,
+	startRowOffset int64,
+	cfg *config.Config,
+) (*batchLoader[*cloudURL], dbsqlerr.DBError) {
 
 	if cfg == nil {
 		cfg = config.WithDefaults()
@@ -98,41 +160,6 @@ func newCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowRe
 	return cbl, nil
 }
 
-func newLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) {
-
-	if cfg == nil {
-		cfg = config.WithDefaults()
-	}
-
-	var startRow int64 = startRowOffset
-	var rowCount int64
-	inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(batches))
-	for i := range batches {
-		b := batches[i]
-		if b != nil {
-			li := &localBatch{
-				Delimiter:         rowscanner.NewDelimiter(startRow, b.RowCount),
-				batchBytes:        b.Batch,
-				arrowSchemaBytes:  arrowSchemaBytes,
-				compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression},
-			}
-			inputChan <- li
-			startRow = startRow + b.RowCount
-			rowCount += b.RowCount
-		}
-	}
-	close(inputChan)
-
-	f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan)
-	cbl := &batchLoader[*localBatch]{
-		Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount),
-		fetcher:   f,
-		ctx:       ctx,
-	}
-
-	return cbl, nil
-}
-
 type batchLoader[T interface {
 	Fetch(ctx context.Context) (SparkArrowBatch, error)
 }] struct {
@@ -199,8 +226,8 @@ type compressibleBatch struct {
 	useLz4Compression bool
 }
 
-func (cb compressibleBatch) getReader(r io.Reader) io.Reader {
-	if cb.useLz4Compression {
+func getReader(r io.Reader, useLz4Compression bool) io.Reader {
+	if useLz4Compression {
 		return lz4.NewReader(r)
 	}
 	return r
@@ -237,7 +264,7 @@ func (cu *cloudURL) Fetch(ctx context.Context) (SparkArrowBatch, error) {
 
 	defer res.Body.Close()
 
-	r := cu.compressibleBatch.getReader(res.Body)
+	r := getReader(res.Body, cu.compressibleBatch.useLz4Compression)
 
 	records, err := getArrowRecords(r, cu.Start())
 	if err != nil {
@@ -269,7 +296,7 @@ type localBatch struct {
 var _ fetcher.FetchableItems[SparkArrowBatch] = (*localBatch)(nil)
 
 func (lb *localBatch) Fetch(ctx context.Context) (SparkArrowBatch, error) {
-	r := lb.compressibleBatch.getReader(bytes.NewReader(lb.batchBytes))
+	r := getReader(bytes.NewReader(lb.batchBytes), lb.compressibleBatch.useLz4Compression)
	r = io.MultiReader(bytes.NewReader(lb.arrowSchemaBytes), r)
 
 	records, err := getArrowRecords(r, lb.Start())
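
For orientation, below is a minimal sketch of how a caller might drive the new lazily-decoding iterator. It is illustrative only: consumeLocalBatches is a hypothetical helper, the import paths are assumed from the repository layout, and cfg must be non-nil because the new constructor reads cfg.UseLz4Compression directly. Only NewLocalBatchIterator and the Next/HasNext/Close contract (with io.EOF signalling exhaustion) are taken from the change above.

package arrowbased

import (
	"context"
	"errors"
	"io"

	"github.com/databricks/databricks-sql-go/internal/cli_service"
	"github.com/databricks/databricks-sql-go/internal/config"
)

// consumeLocalBatches is a hypothetical helper that walks every batch with the
// new iterator. Decoding happens lazily inside Next(); exhaustion is reported
// both by HasNext() returning false and by Next() returning io.EOF.
func consumeLocalBatches(
	ctx context.Context,
	batches []*cli_service.TSparkArrowBatch,
	startRowOffset int64,
	arrowSchemaBytes []byte,
	cfg *config.Config, // must be non-nil; the iterator reads cfg.UseLz4Compression
) error {
	bi, err := NewLocalBatchIterator(ctx, batches, startRowOffset, arrowSchemaBytes, cfg)
	if err != nil {
		return err
	}
	defer bi.Close()

	for bi.HasNext() {
		batch, nextErr := bi.Next()
		if errors.Is(nextErr, io.EOF) {
			break // all batches consumed
		}
		if nextErr != nil {
			return nextErr
		}
		_ = batch // hand the SparkArrowBatch to the row scanner here
	}

	return nil
}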