diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go
index 2f52efa..41fb459 100644
--- a/internal/rows/arrowbased/batchloader.go
+++ b/internal/rows/arrowbased/batchloader.go
@@ -131,8 +131,13 @@ func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) {
 			link.StartRowOffset,
 			link.RowCount,
 		)
+
+		// Derive a per-task cancellable context from bi.ctx so a single task
+		// can be stopped without cancelling the iterator's own context.
+		cancelCtx, cancelFn := context.WithCancel(bi.ctx)
 		task := &cloudFetchDownloadTask{
-			ctx:               bi.ctx,
+			ctx:               cancelCtx,
+			cancel:            cancelFn,
 			useLz4Compression: bi.cfg.UseLz4Compression,
 			link:              link,
 			resultChan:        make(chan cloudFetchDownloadTaskResult),
@@ -147,7 +152,20 @@ func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) {
 		return nil, io.EOF
 	}
 
-	return task.GetResult()
+	batch, err := task.GetResult()
+
+	// Release this task's context on both the success and the error path.
+	// A CancelFunc may safely be called more than once, so this is harmless
+	// even if Close cancels the same task again.
+	task.cancel()
+
+	// once we've got an errored out task - cancel the remaining ones
+	if err != nil {
+		bi.Close()
+		return nil, err
+	}
+
+	return batch, nil
 }
 
 func (bi *cloudBatchIterator) HasNext() bool {
@@ -155,9 +173,13 @@ func (bi *cloudBatchIterator) HasNext() bool {
 }
 
+// Close clears the pending links and cancels every download task still in the
+// queue, dequeuing each one as it goes.
 func (bi *cloudBatchIterator) Close() {
-	bi.pendingLinks.Clear() // Clear the list
-	// TODO: Cancel all download tasks
-	bi.downloadTasks.Clear() // Clear the list
+	bi.pendingLinks.Clear()
+	for bi.downloadTasks.Len() > 0 {
+		task := bi.downloadTasks.Dequeue()
+		task.cancel()
+	}
 }
 
 type cloudFetchDownloadTaskResult struct {
@@ -167,6 +189,7 @@ type cloudFetchDownloadTaskResult struct {
 
 type cloudFetchDownloadTask struct {
 	ctx               context.Context
+	cancel            context.CancelFunc // cancels ctx; set alongside ctx when the task is created
 	useLz4Compression bool
 	minTimeToExpiry   time.Duration
 	link              *cli_service.TSparkArrowResultLink
@@ -180,9 +203,10 @@ func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) {
 	if ok {
 		if result.err != nil {
 			logger.Debug().Msgf(
-				"CloudFetch: failed to download link at offset %d row count %d",
+				"CloudFetch: failed to download link at offset %d row count %d, reason: %s",
 				link.StartRowOffset,
 				link.RowCount,
+				result.err.Error(),
 			)
 			return nil, result.err
 		}