Split out regular Arrow results handling from CloudFetch
Signed-off-by: Levko Kravets <[email protected]>
kravets-levko committed Jul 9, 2024
1 parent 92d1ef4 commit 72540a9
Showing 1 changed file with 76 additions and 49 deletions.
125 changes: 76 additions & 49 deletions internal/rows/arrowbased/batchloader.go
@@ -33,7 +33,12 @@ type BatchLoader interface {
 	Close()
 }
 
-func NewCloudBatchIterator(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) {
+func NewCloudBatchIterator(
+	ctx context.Context,
+	files []*cli_service.TSparkArrowResultLink,
+	startRowOffset int64,
+	cfg *config.Config,
+) (BatchIterator, dbsqlerr.DBError) {
 	bl, err := newCloudBatchLoader(ctx, files, startRowOffset, cfg)
 	if err != nil {
 		return nil, err
@@ -47,21 +52,78 @@ func NewCloudBatchIterator(ctx context.Context, files []*cli_service.TSparkArrow
 	return bi, nil
 }
 
-func NewLocalBatchIterator(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (BatchIterator, dbsqlerr.DBError) {
-	bl, err := newLocalBatchLoader(ctx, batches, startRowOffset, arrowSchemaBytes, cfg)
-	if err != nil {
-		return nil, err
-	}
-
-	bi := &batchIterator{
-		nextBatchStart: bl.Start(),
-		batchLoader:    bl,
-	}
-
-	return bi, nil
-}
-
-func newCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) {
+func NewLocalBatchIterator(
+	ctx context.Context,
+	batches []*cli_service.TSparkArrowBatch,
+	startRowOffset int64,
+	arrowSchemaBytes []byte,
+	cfg *config.Config,
+) (BatchIterator, dbsqlerr.DBError) {
+	bi := &localBatchIterator{
+		useLz4Compression: cfg.UseLz4Compression,
+		startRowOffset:    startRowOffset,
+		arrowSchemaBytes:  arrowSchemaBytes,
+		batches:           batches,
+		index:             -1,
+	}
+
+	return bi, nil
+}
+
+type localBatchIterator struct {
+	useLz4Compression bool
+	startRowOffset    int64
+	arrowSchemaBytes  []byte
+	batches           []*cli_service.TSparkArrowBatch
+	index             int
+}
+
+var _ BatchIterator = (*localBatchIterator)(nil)
+
+func (bi *localBatchIterator) Next() (SparkArrowBatch, error) {
+	cnt := len(bi.batches)
+	bi.index++
+	if bi.index < cnt {
+		ab := bi.batches[bi.index]
+
+		reader := io.MultiReader(
+			bytes.NewReader(bi.arrowSchemaBytes),
+			getReader(bytes.NewReader(ab.Batch), bi.useLz4Compression),
+		)
+
+		records, err := getArrowRecords(reader, bi.startRowOffset)
+		if err != nil {
+			return &sparkArrowBatch{}, err
+		}
+
+		batch := sparkArrowBatch{
+			Delimiter:    rowscanner.NewDelimiter(bi.startRowOffset, ab.RowCount),
+			arrowRecords: records,
+		}
+
+		bi.startRowOffset += ab.RowCount // advance to beginning of the next batch
+
+		return &batch, nil
+	}
+
+	bi.index = cnt
+	return nil, io.EOF
+}
+
+func (bi *localBatchIterator) HasNext() bool {
+	return bi.index < len(bi.batches)
+}
+
+func (bi *localBatchIterator) Close() {
+	bi.index = len(bi.batches)
+}
+
+func newCloudBatchLoader(
+	ctx context.Context,
+	files []*cli_service.TSparkArrowResultLink,
+	startRowOffset int64,
+	cfg *config.Config,
+) (*batchLoader[*cloudURL], dbsqlerr.DBError) {
 
 	if cfg == nil {
 		cfg = config.WithDefaults()
@@ -98,41 +160,6 @@ func newCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowRe
 	return cbl, nil
 }
 
-func newLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) {
-
-	if cfg == nil {
-		cfg = config.WithDefaults()
-	}
-
-	var startRow int64 = startRowOffset
-	var rowCount int64
-	inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(batches))
-	for i := range batches {
-		b := batches[i]
-		if b != nil {
-			li := &localBatch{
-				Delimiter:         rowscanner.NewDelimiter(startRow, b.RowCount),
-				batchBytes:        b.Batch,
-				arrowSchemaBytes:  arrowSchemaBytes,
-				compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression},
-			}
-			inputChan <- li
-			startRow = startRow + b.RowCount
-			rowCount += b.RowCount
-		}
-	}
-	close(inputChan)
-
-	f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan)
-	cbl := &batchLoader[*localBatch]{
-		Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount),
-		fetcher:   f,
-		ctx:       ctx,
-	}
-
-	return cbl, nil
-}
-
 type batchLoader[T interface {
 	Fetch(ctx context.Context) (SparkArrowBatch, error)
 }] struct {
@@ -199,8 +226,8 @@ type compressibleBatch struct {
 	useLz4Compression bool
 }
 
-func (cb compressibleBatch) getReader(r io.Reader) io.Reader {
-	if cb.useLz4Compression {
+func getReader(r io.Reader, useLz4Compression bool) io.Reader {
+	if useLz4Compression {
 		return lz4.NewReader(r)
 	}
 	return r
@@ -237,7 +264,7 @@ func (cu *cloudURL) Fetch(ctx context.Context) (SparkArrowBatch, error) {
 
 	defer res.Body.Close()
 
-	r := cu.compressibleBatch.getReader(res.Body)
+	r := getReader(res.Body, cu.compressibleBatch.useLz4Compression)
 
 	records, err := getArrowRecords(r, cu.Start())
 	if err != nil {
@@ -269,7 +296,7 @@ type localBatch struct {
 var _ fetcher.FetchableItems[SparkArrowBatch] = (*localBatch)(nil)
 
 func (lb *localBatch) Fetch(ctx context.Context) (SparkArrowBatch, error) {
-	r := lb.compressibleBatch.getReader(bytes.NewReader(lb.batchBytes))
+	r := getReader(bytes.NewReader(lb.batchBytes), lb.compressibleBatch.useLz4Compression)
 	r = io.MultiReader(bytes.NewReader(lb.arrowSchemaBytes), r)
 
 	records, err := getArrowRecords(r, lb.Start())
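For context on how the split-out iterator is meant to be driven: a minimal consumption sketch, assuming it lives in the same arrowbased package with the standard errors and io imports (drainBatches is a hypothetical helper, not part of this commit). Next returns io.EOF once every local batch has been handed out, so a caller can loop on HasNext:

func drainBatches(bi BatchIterator) error {
	defer bi.Close()
	for bi.HasNext() {
		batch, err := bi.Next()
		if err != nil {
			if errors.Is(err, io.EOF) {
				break // all batches consumed
			}
			return err
		}
		// hand the SparkArrowBatch to a row scanner here
		_ = batch
	}
	return nil
}

Note the design change the diff makes: local batches no longer go through newLocalBatchLoader and the concurrent fetcher pipeline at all; localBatchIterator decodes each TSparkArrowBatch lazily inside Next, while the fetcher-backed batchLoader remains only for the CloudFetch path.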

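The other refactor here turns compressibleBatch.getReader into a free getReader function, so the cloud path, the local Fetch path, and the new iterator can all share it. A round-trip sketch under the same assumptions — same package, with bytes, io, and the pierrec lz4 package already imported; readMaybeCompressed is a hypothetical helper:

func readMaybeCompressed(payload []byte, compress bool) ([]byte, error) {
	data := payload
	if compress {
		// produce an LZ4-framed copy of the payload
		var buf bytes.Buffer
		zw := lz4.NewWriter(&buf)
		if _, err := zw.Write(payload); err != nil {
			return nil, err
		}
		if err := zw.Close(); err != nil {
			return nil, err
		}
		data = buf.Bytes()
	}
	// getReader returns the reader unchanged when compression is off,
	// or an lz4 reader wrapping it when it is on
	return io.ReadAll(getReader(bytes.NewReader(data), compress))
}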