diff --git a/cmd/timescaledb-parallel-copy/main.go b/cmd/timescaledb-parallel-copy/main.go
index 4495d7f..70fa764 100644
--- a/cmd/timescaledb-parallel-copy/main.go
+++ b/cmd/timescaledb-parallel-copy/main.go
@@ -33,10 +33,12 @@ var (
 	quoteCharacter  string
 	escapeCharacter string
 
-	fromFile       string
-	columns        string
-	skipHeader     bool
-	headerLinesCnt int
+	fromFile            string
+	columns             string
+	skipHeader          bool
+	headerLinesCnt      int
+	batchErrorOutputDir string
+	skipBatchErrors     bool
 
 	workers int
 	limit   int64
@@ -67,6 +69,9 @@ func init() {
 	flag.BoolVar(&skipHeader, "skip-header", false, "Skip the first line of the input")
 	flag.IntVar(&headerLinesCnt, "header-line-count", 1, "Number of header lines")
 
+	flag.StringVar(&batchErrorOutputDir, "batch-error-output-dir", "", "Directory to store batch errors. Setting this will save a .csv file with the contents of each failed batch and continue with the rest of the data.")
+	flag.BoolVar(&skipBatchErrors, "skip-batch-errors", false, "If true, the copy will continue even if a batch fails")
+
 	flag.IntVar(&batchSize, "batch-size", 5000, "Number of rows per insert")
 	flag.Int64Var(&limit, "limit", 0, "Number of rows to insert overall; 0 means to insert all")
 	flag.IntVar(&workers, "workers", 1, "Number of parallel requests to make")
@@ -94,8 +99,10 @@ func main() {
 	if dbName != "" {
 		log.Fatalf("Error: Deprecated flag -db-name is being used. Update -connection to connect to the given database")
 	}
+	logger := &csvCopierLogger{}
+
 	opts := []csvcopy.Option{
-		csvcopy.WithLogger(&csvCopierLogger{}),
+		csvcopy.WithLogger(logger),
 		csvcopy.WithSchemaName(schemaName),
 		csvcopy.WithCopyOptions(copyOptions),
 		csvcopy.WithSplitCharacter(splitCharacter),
@@ -110,6 +117,19 @@ func main() {
 		csvcopy.WithVerbose(verbose),
 	}
 
+	batchErrorHandler := csvcopy.BatchHandlerError()
+	if skipBatchErrors {
+		batchErrorHandler = csvcopy.BatchHandlerNoop()
+	}
+	if batchErrorOutputDir != "" {
+		log.Printf("batch errors will be stored at %s", batchErrorOutputDir)
+		batchErrorHandler = csvcopy.BatchHandlerSaveToFile(batchErrorOutputDir, batchErrorHandler)
+	}
+	if verbose || skipBatchErrors {
+		batchErrorHandler = csvcopy.BatchHandlerLog(logger, batchErrorHandler)
+	}
+	opts = append(opts, csvcopy.WithBatchErrorHandler(batchErrorHandler))
+
 	if skipHeader {
 		opts = append(opts,
 			csvcopy.WithSkipHeaderCount(headerLinesCnt),
diff --git a/internal/batch/scan.go b/pkg/batch/scan.go
similarity index 77%
rename from internal/batch/scan.go
rename to pkg/batch/scan.go
index 7936d31..96d5d0a 100644
--- a/internal/batch/scan.go
+++ b/pkg/batch/scan.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"log"
 	"net"
 )
 
@@ -22,18 +23,55 @@ type Options struct {
 type Batch struct {
 	Data     net.Buffers
 	Location Location
+
+	// backup holds the same data as Data. It is used to rewind if something goes wrong.
+	// Because only the slice is copied, the memory is not duplicated;
+	// because the data is only read, the underlying memory is not modified either.
+	backup net.Buffers
+}
+
+func NewBatch(data net.Buffers, location Location) Batch {
+	b := Batch{
+		Data:     data,
+		Location: location,
+	}
+	b.snapshot()
+	return b
+}
+
+func (b *Batch) snapshot() {
+	b.backup = net.Buffers{}
+	b.backup = append(b.backup, b.Data...)
+}
+
+// Rewind makes the data available to read again
+func (b *Batch) Rewind() {
+	b.Data = net.Buffers{}
+	b.Data = append(b.Data, b.backup...)
 }
 
 // Location positions a batch within the original data
 type Location struct {
+	// StartRow represents the index of the row where the batch starts.
+	// The first row of the file is row 0.
+	// The header counts as a row.
 	StartRow int64
-	Length   int
+	// RowCount is the number of rows in the batch
+	RowCount int
+	// ByteOffset is the byte position in the original file.
+	// It can be used with ReadAt to process the same batch again.
+	ByteOffset int
+	// ByteLen represents the number of bytes in the batch.
+	// It can be used together with ByteOffset to read the batch back in full.
+	ByteLen int
 }
 
-func NewLocation(rowsRead int64, bufferedRows int, skip int) Location {
+func NewLocation(rowsRead int64, bufferedRows int, skip int, byteOffset int, byteLen int) Location {
 	return Location{
-		StartRow: rowsRead - int64(bufferedRows) + int64(skip),
-		Length:   bufferedRows,
+		StartRow:   rowsRead - int64(bufferedRows) + int64(skip) - 1, // Index rows starting at 0
+		RowCount:   bufferedRows,
+		ByteOffset: byteOffset,
+		ByteLen:    byteLen,
 	}
 }
 
@@ -49,7 +87,8 @@ func NewLocation(rowsRead int64, bufferedRows int, skip int) Location {
 // opts.Escape as the QUOTE and ESCAPE characters used for the CSV input.
 func Scan(ctx context.Context, r io.Reader, out chan<- Batch, opts Options) error {
 	var rowsRead int64
-	reader := bufio.NewReader(r)
+	counter := &CountReader{Reader: r}
+	reader := bufio.NewReader(counter)
 
 	for skip := opts.Skip; skip > 0; {
 		// The use of ReadLine() here avoids copying or buffering data that
@@ -62,7 +101,6 @@ func Scan(ctx context.Context, r io.Reader, out chan<- Batch, opts Options) erro
 		} else if err != nil {
 			return fmt.Errorf("skipping header: %w", err)
 		}
-
 		if !isPrefix {
 			// We pulled a full row from the buffer.
 			skip--
@@ -91,6 +129,7 @@ func Scan(ctx context.Context, r io.Reader, out chan<- Batch, opts Options) erro
 	bufs := make(net.Buffers, 0, opts.Size)
 	var bufferedRows int
+	byteStart := counter.Total - reader.Buffered()
 
 	for {
 		eol := false
@@ -130,16 +169,18 @@
 		}
 		if bufferedRows >= opts.Size { // dispatch to COPY worker & reset
+			byteEnd := counter.Total - reader.Buffered()
 			select {
-			case out <- Batch{
-				Data:     bufs,
-				Location: NewLocation(rowsRead, bufferedRows, opts.Skip),
-			}:
+			case out <- NewBatch(
+				bufs,
+				NewLocation(rowsRead, bufferedRows, opts.Skip, byteStart, byteEnd-byteStart),
+			):
 			case <-ctx.Done():
 				return ctx.Err()
 			}
 			bufs = make(net.Buffers, 0, opts.Size)
 			bufferedRows = 0
+			byteStart = byteEnd
 		}
 	}
 
@@ -153,15 +194,17 @@
 	// Finished reading input, make sure last batch goes out.
 	if len(bufs) > 0 {
+		byteEnd := counter.Total - reader.Buffered()
 		select {
-		case out <- Batch{
-			Data:     bufs,
-			Location: NewLocation(rowsRead, bufferedRows, opts.Skip),
-		}:
+		case out <- NewBatch(
+			bufs,
+			NewLocation(rowsRead, bufferedRows, opts.Skip, byteStart, byteEnd-byteStart),
+		):
 		case <-ctx.Done():
 			return ctx.Err()
 		}
 	}
+	log.Print("total rows ", rowsRead)
 
 	return nil
 }
@@ -257,3 +300,15 @@ func (c *csvRowState) NeedsMore() bool {
 	// c.inQuote is also true.
 	return c.inQuote
 }
+
+// CountReader is a wrapper that counts how many bytes have been read from the given reader
+type CountReader struct {
+	Reader io.Reader
+	Total  int
+}
+
+func (c *CountReader) Read(b []byte) (int, error) {
+	n, err := c.Reader.Read(b)
+	c.Total += n
+	return n, err
+}
diff --git a/internal/batch/scan_internal_test.go b/pkg/batch/scan_internal_test.go
similarity index 100%
rename from internal/batch/scan_internal_test.go
rename to pkg/batch/scan_internal_test.go
diff --git a/internal/batch/scan_test.go b/pkg/batch/scan_test.go
similarity index 88%
rename from internal/batch/scan_test.go
rename to pkg/batch/scan_test.go
index f7b8b01..1786edb 100644
--- a/internal/batch/scan_test.go
+++ b/pkg/batch/scan_test.go
@@ -6,11 +6,14 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"reflect"
 	"strings"
 	"testing"
 
-	"github.com/timescale/timescaledb-parallel-copy/internal/batch"
+	"github.com/stretchr/testify/require"
+	"github.com/timescale/timescaledb-parallel-copy/pkg/batch"
+	"golang.org/x/exp/rand"
 )
 
 func TestScan(t *testing.T) {
@@ -428,3 +431,43 @@ func BenchmarkScan(b *testing.B) {
 		}
 	}
 }
+
+var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,")
+
+func RandString(n int) string {
+	b := make([]rune, n)
+	for i := range b {
+		b[i] = letterRunes[rand.Intn(len(letterRunes))]
+	}
+	return string(b)
+}
+
+func TestRewind(t *testing.T) {
+	randomData := RandString(5000)
+	data := net.Buffers(bytes.Split([]byte(randomData), []byte(",")))
+
+	batch := batch.NewBatch(data, batch.NewLocation(0, 0, 0, 0, 0))
+
+	var err error
+	// Reading the batch consumes all the data
+	buf := bytes.Buffer{}
+	_, err = buf.ReadFrom(&batch.Data)
+	require.NoError(t, err)
+	require.Equal(t, strings.Replace(randomData, ",", "", -1), buf.String())
+	require.Empty(t, batch.Data)
+
+	// Reading again returns nothing
+	buf = bytes.Buffer{}
+	_, err = buf.ReadFrom(&batch.Data)
+	require.NoError(t, err)
+	require.Empty(t, buf.String())
+	require.Empty(t, batch.Data)
+
+	// Reading again after Rewind returns all the data
+	batch.Rewind()
+	buf = bytes.Buffer{}
+	_, err = buf.ReadFrom(&batch.Data)
+	require.NoError(t, err)
+	require.Equal(t, strings.Replace(randomData, ",", "", -1), buf.String())
+}
diff --git a/pkg/csvcopy/batch_error.go b/pkg/csvcopy/batch_error.go
new file mode 100644
index 0000000..045e848
--- /dev/null
+++ b/pkg/csvcopy/batch_error.go
@@ -0,0 +1,62 @@
+package csvcopy
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/timescale/timescaledb-parallel-copy/pkg/batch"
+)
+
+// BatchHandlerSaveToFile writes the contents of each failed batch to a .csv file in the given directory, named after the batch's start row.
+func BatchHandlerSaveToFile(dir string, next BatchErrorHandler) BatchErrorHandler {
+	return BatchErrorHandler(func(batch batch.Batch, reason error) error {
+		err := os.MkdirAll(dir, os.ModePerm)
+		if err != nil {
+			return fmt.Errorf("failed to ensure directory exists: %w", err)
+		}
+
+		fileName := fmt.Sprintf("%d.csv", batch.Location.StartRow)
+		path := filepath.Join(dir, fileName)
+
+		dst, err := os.Create(path)
+		if err != nil {
+			return fmt.Errorf("failed to create file to store batch error: %w", err)
+		}
+		defer dst.Close()
+
+		batch.Rewind()
+		_, err = io.Copy(dst, &batch.Data)
+		if err != nil {
+			return fmt.Errorf("failed to write file to store batch error: %w", err)
+		}
+
+		if next != nil {
+			return next(batch, reason)
+		}
+		return nil
+	})
+}
+
+// BatchHandlerLog prints a log line that reports the error in the given batch
+func BatchHandlerLog(log Logger, next BatchErrorHandler) BatchErrorHandler {
+	return BatchErrorHandler(func(batch batch.Batch, reason error) error {
+		log.Infof("Batch %d, starting at byte %d with len %d, has error: %s", batch.Location.StartRow, batch.Location.ByteOffset, batch.Location.ByteLen, reason.Error())
+
+		if next != nil {
+			return next(batch, reason)
+		}
+		return nil
+	})
+}
+
+// BatchHandlerNoop ignores the error so the copy continues with the next batch
+func BatchHandlerNoop() BatchErrorHandler {
+	return BatchErrorHandler(func(_ batch.Batch, _ error) error { return nil })
+}
+
+// BatchHandlerError returns the error unchanged, stopping the copy
+func BatchHandlerError() BatchErrorHandler {
+	return BatchErrorHandler(func(_ batch.Batch, err error) error { return err })
+}
diff --git a/pkg/csvcopy/csvcopy.go b/pkg/csvcopy/csvcopy.go
index b9ad4fa..a102431 100644
--- a/pkg/csvcopy/csvcopy.go
+++ b/pkg/csvcopy/csvcopy.go
@@ -2,7 +2,6 @@ package csvcopy
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"regexp"
@@ -14,8 +13,8 @@ import (
 	"github.com/jackc/pgx/v5/pgconn"
 	_ "github.com/jackc/pgx/v5/stdlib"
 
-	"github.com/timescale/timescaledb-parallel-copy/internal/batch"
 	"github.com/timescale/timescaledb-parallel-copy/internal/db"
+	"github.com/timescale/timescaledb-parallel-copy/pkg/batch"
 )
 
 const TAB_CHAR_STR = "\\t"
@@ -47,6 +46,8 @@ type Copier struct {
 	verbose  bool
 	skip     int
 	rowCount int64
+
+	failHandler BatchErrorHandler
 }
 
 func NewCopier(
@@ -78,7 +79,7 @@ func NewCopier(
 	for _, o := range options {
 		err := o(copier)
 		if err != nil {
-			return nil, fmt.Errorf("Error processing option, %T, %w", o, err)
+			return nil, fmt.Errorf("failed to execute option %T: %w", o, err)
 		}
 	}
 
@@ -183,48 +184,56 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
 	rowsRead := atomic.LoadInt64(&c.rowCount)
 	rowRate := float64(rowsRead) / float64(took.Seconds())
 
-	return Result{
+	result := Result{
 		RowsRead: rowsRead,
 		Duration: took,
 		RowRate:  rowRate,
-	}, err
+	}
+
+	if err != nil {
+		return result, err
+	}
+	return result, nil
 }
 
 type ErrAtRow struct {
 	Err error
-	Row int64
+	// Row is the row number reported by the PgError.
+	// The value is relative to the batch location.
+	Row           int
+	BatchLocation batch.Location
 }
 
-func ErrAtRowFromPGError(pgerr *pgconn.PgError, offset int64) *ErrAtRow {
+// RowAtLocation returns the row number adjusted by the batch location
+// so that the number matches the original file
+func (e *ErrAtRow) RowAtLocation() int {
+	if e.Row == -1 {
+		return -1
+	}
+	return e.Row + int(e.BatchLocation.StartRow)
+}
+
+// ExtractRowFrom parses the row number within the batch out of the PgError's Where field, returning -1 if it cannot be determined
+func ExtractRowFrom(pgerr *pgconn.PgError) int {
 	// Example of Where field
 	// "COPY metrics, line 1, column value: \"hello\""
 	match := regexp.MustCompile(`line (\d+)`).FindStringSubmatch(pgerr.Where)
 	if len(match) != 2 {
-		return &ErrAtRow{
-			Err: pgerr,
-			Row: -1,
-		}
+		return -1
 	}
 
 	line, err := strconv.Atoi(match[1])
 	if err != nil {
-		return &ErrAtRow{
-			Err: pgerr,
-			Row: -1,
-		}
+		return -1
 	}
-
-	return &ErrAtRow{
-		Err: pgerr,
-		Row: offset + int64(line),
-	}
+	return line
 }
 
 func (e ErrAtRow) Error() string {
 	if e.Err != nil {
-		return fmt.Sprintf("at row %d, error %s", e.Row, e.Err.Error())
+		return fmt.Sprintf("at row %d, error %s", e.RowAtLocation(), e.Err.Error())
 	}
-	return fmt.Sprintf("error at row %d", e.Row)
+	return fmt.Sprintf("error at row %d", e.RowAtLocation())
 }
 
 func (e ErrAtRow) Unwrap() error {
@@ -274,24 +283,40 @@ func (c *Copier) processBatches(ctx context.Context, ch chan batch.Batch) (err e
 			if !ok {
 				return
 			}
+
 			start := time.Now()
 			rows, err := db.CopyFromLines(ctx, dbx, &batch.Data, copyCmd)
 			if err != nil {
-				pgErr := &pgconn.PgError{}
-				if errors.As(err, &pgErr) {
-					return ErrAtRowFromPGError(pgErr, batch.Location.StartRow)
+				err = c.handleCopyError(batch, err)
+				if err != nil {
+					return err
 				}
-				return fmt.Errorf("[BATCH] starting at row %d: %w", batch.Location.StartRow, err)
 			}
 			atomic.AddInt64(&c.rowCount, rows)
 
 			if c.logBatches {
 				took := time.Since(start)
-				fmt.Printf("[BATCH] starting at row %d, took %v, batch size %d, row rate %f/sec\n", batch.Location.StartRow, took, batch.Location.Length, float64(batch.Location.Length)/float64(took.Seconds()))
+				fmt.Printf("[BATCH] starting at row %d, took %v, batch size %d, row rate %f/sec\n", batch.Location.StartRow, took, batch.Location.RowCount, float64(batch.Location.RowCount)/float64(took.Seconds()))
 			}
 		}
 	}
 }
 
+func (c *Copier) handleCopyError(batch batch.Batch, err error) error {
+	errAt := &ErrAtRow{
+		Err:           err,
+		Row:           -1, // -1 unless a row number can be extracted from a PgError below
+		BatchLocation: batch.Location,
+	}
+	if pgerr, ok := err.(*pgconn.PgError); ok {
+		errAt.Row = ExtractRowFrom(pgerr)
+	}
+
+	if c.failHandler != nil {
+		batch.Rewind()
+		return c.failHandler(batch, errAt)
+	}
+	return errAt
+}
+
 // report periodically prints the write rate in number of rows per second
 func (c *Copier) report(ctx context.Context) {
diff --git a/pkg/csvcopy/csvcopy_test.go b/pkg/csvcopy/csvcopy_test.go
index 04d4aee..3663219 100644
--- a/pkg/csvcopy/csvcopy_test.go
+++ b/pkg/csvcopy/csvcopy_test.go
@@ -1,8 +1,10 @@
 package csvcopy
 
 import (
+	"bytes"
 	"context"
 	"encoding/csv"
+	"fmt"
 	"os"
 	"testing"
 	"time"
@@ -13,6 +15,7 @@ import (
 	"github.com/testcontainers/testcontainers-go"
 	"github.com/testcontainers/testcontainers-go/modules/postgres"
 	"github.com/testcontainers/testcontainers-go/wait"
+	"github.com/timescale/timescaledb-parallel-copy/pkg/batch"
 )
 
 func TestWriteDataToCSV(t *testing.T) {
@@ -154,7 +157,84 @@ func TestErrorAtRow(t *testing.T) {
 
 	writer.Flush()
 
-	copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"))
+	copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"), WithBatchSize(2))
+	require.NoError(t, err)
+	reader, err := os.Open(tmpfile.Name())
+	require.NoError(t, err)
+	_, err = copier.Copy(context.Background(), reader)
+	assert.Error(t, err)
+	errAtRow := &ErrAtRow{}
+	assert.ErrorAs(t, err, &errAtRow)
+	assert.EqualValues(t, 3, errAtRow.RowAtLocation())
+
+	prev := `42,xasev,4.2
+24,qased,2.4
+`
+	assert.EqualValues(t, len(prev), errAtRow.BatchLocation.ByteOffset)
+	batch := `24,qased,2.4
+24,qased,hello
+`
+	assert.EqualValues(t, len(batch), errAtRow.BatchLocation.ByteLen)
+}
+
+func TestErrorAtRowWithHeader(t *testing.T) {
+	ctx := context.Background()
+
+	pgContainer, err := postgres.RunContainer(ctx,
+		
testcontainers.WithImage("postgres:15.3-alpine"), + postgres.WithDatabase("test-db"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2).WithStartupTimeout(5*time.Second)), + ) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := pgContainer.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate pgContainer: %s", err) + } + }) + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + conn, err := pgx.Connect(ctx, connStr) + require.NoError(t, err) + defer conn.Close(ctx) + _, err = conn.Exec(ctx, "create table public.metrics (device_id int, label text, value float8)") + require.NoError(t, err) + + // Create a temporary CSV file + tmpfile, err := os.CreateTemp("", "example") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + // Write data to the CSV file + writer := csv.NewWriter(tmpfile) + + data := [][]string{ + {"number", "text", "float"}, + {"42", "xasev", "4.2"}, + {"24", "qased", "2.4"}, + {"24", "qased", "2.4"}, + {"24", "qased", "hello"}, + {"24", "qased", "2.4"}, + {"24", "qased", "2.4"}, + } + + for _, record := range data { + if err := writer.Write(record); err != nil { + t.Fatalf("Error writing record to CSV: %v", err) + } + } + + writer.Flush() + + copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"), WithSkipHeader(true), WithBatchSize(2)) require.NoError(t, err) reader, err := os.Open(tmpfile.Name()) require.NoError(t, err) @@ -162,7 +242,17 @@ func TestErrorAtRow(t *testing.T) { assert.Error(t, err) errAtRow := &ErrAtRow{} assert.ErrorAs(t, err, &errAtRow) - assert.EqualValues(t, 4, errAtRow.Row) + assert.EqualValues(t, 4, errAtRow.RowAtLocation()) + + prev := `number,text,float +42,xasev,4.2 +24,qased,2.4 +` + assert.EqualValues(t, len(prev), errAtRow.BatchLocation.ByteOffset) + batch := `24,qased,2.4 +24,qased,hello +` + assert.EqualValues(t, len(batch), errAtRow.BatchLocation.ByteLen) } func TestWriteReportProgress(t *testing.T) { @@ -255,3 +345,166 @@ func TestWriteReportProgress(t *testing.T) { require.NoError(t, err) assert.Equal(t, []interface{}{int32(24), "qased", 2.4}, results) } + +func TestFailedBatchHandler(t *testing.T) { + ctx := context.Background() + + pgContainer, err := postgres.RunContainer(ctx, + testcontainers.WithImage("postgres:15.3-alpine"), + postgres.WithDatabase("test-db"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). 
+ WithOccurrence(2).WithStartupTimeout(5*time.Second)), + ) + require.NoError(t, err) + + t.Cleanup(func() { + err := pgContainer.Terminate(ctx) + require.NoError(t, err) + }) + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + conn, err := pgx.Connect(ctx, connStr) + require.NoError(t, err) + defer conn.Close(ctx) + _, err = conn.Exec(ctx, "create table public.metrics (device_id int, label text, value float8)") + require.NoError(t, err) + + // Create a temporary CSV file + tmpfile, err := os.CreateTemp("", "example") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + // Write data to the CSV file + writer := csv.NewWriter(tmpfile) + + data := [][]string{ + // Batch 1 + {"42", "xasev", "4.2"}, + {"24", "qased", "2.4"}, + // Batch 2 + {"24", "qased", "2.4"}, + {"24", "qased", "hello"}, + // Batch 3 + {"24", "qased", "2.4"}, + {"24", "qased", "2.4"}, + } + + for _, record := range data { + if err := writer.Write(record); err != nil { + t.Fatalf("Error writing record to CSV: %v", err) + } + } + + writer.Flush() + fs := &MockErrorHandler{} + + copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"), WithBatchSize(2), WithBatchErrorHandler(fs.HandleError)) + require.NoError(t, err) + reader, err := os.Open(tmpfile.Name()) + require.NoError(t, err) + result, err := copier.Copy(context.Background(), reader) + require.NoError(t, err) + require.EqualValues(t, 4, result.RowsRead) + + require.Contains(t, fs.Files, 1) + require.Equal(t, fs.Files[1].String(), "24,qased,2.4\n24,qased,hello\n") + require.Contains(t, fs.Errors, 1) + assert.EqualValues(t, fs.Errors[1].(*ErrAtRow).RowAtLocation(), 3) + assert.EqualValues(t, fs.Errors[1].(*ErrAtRow).BatchLocation.RowCount, 2) + assert.EqualValues(t, fs.Errors[1].(*ErrAtRow).BatchLocation.ByteOffset, 26) + assert.EqualValues(t, fs.Errors[1].(*ErrAtRow).BatchLocation.ByteLen, len("24,qased,2.4\n24,qased,hello\n")) +} + +type MockErrorHandler struct { + Files map[int]*bytes.Buffer + Errors map[int]error +} + +func (fs *MockErrorHandler) HandleError(batch batch.Batch, reason error) error { + if fs.Files == nil { + fs.Files = map[int]*bytes.Buffer{} + } + if fs.Errors == nil { + fs.Errors = map[int]error{} + } + buf := &bytes.Buffer{} + _, err := buf.ReadFrom(&batch.Data) + if err != nil { + return err + } + fs.Files[int(batch.Location.StartRow)] = buf + fs.Errors[int(batch.Location.StartRow)] = reason + return nil +} + +func TestFailedBatchHandlerFailure(t *testing.T) { + ctx := context.Background() + + pgContainer, err := postgres.RunContainer(ctx, + testcontainers.WithImage("postgres:15.3-alpine"), + postgres.WithDatabase("test-db"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). 
+			WithOccurrence(2).WithStartupTimeout(5*time.Second)),
+	)
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		err := pgContainer.Terminate(ctx)
+		require.NoError(t, err)
+	})
+
+	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
+	require.NoError(t, err)
+
+	conn, err := pgx.Connect(ctx, connStr)
+	require.NoError(t, err)
+	defer conn.Close(ctx)
+	_, err = conn.Exec(ctx, "create table public.metrics (device_id int, label text, value float8)")
+	require.NoError(t, err)
+
+	// Create a temporary CSV file
+	tmpfile, err := os.CreateTemp("", "example")
+	require.NoError(t, err)
+	defer os.Remove(tmpfile.Name())
+
+	// Write data to the CSV file
+	writer := csv.NewWriter(tmpfile)
+
+	data := [][]string{
+		// Batch 1
+		{"42", "xasev", "4.2"},
+		{"24", "qased", "2.4"},
+		// Batch 2
+		{"24", "qased", "2.4"},
+		{"24", "qased", "hello"},
+		// Batch 3
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+	}
+
+	for _, record := range data {
+		err := writer.Write(record)
+		require.NoError(t, err, "Error writing record to CSV")
+	}
+
+	writer.Flush()
+
+	copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"), WithBatchSize(2), WithBatchErrorHandler(func(batch batch.Batch, err error) error {
+		return fmt.Errorf("couldn't handle error: %w", err)
+	}))
+	require.NoError(t, err)
+	reader, err := os.Open(tmpfile.Name())
+	require.NoError(t, err)
+	_, err = copier.Copy(context.Background(), reader)
+	require.Error(t, err)
+	require.ErrorContains(t, err, "couldn't handle error")
+}
diff --git a/pkg/csvcopy/options.go b/pkg/csvcopy/options.go
index bc6328e..e0bd8af 100644
--- a/pkg/csvcopy/options.go
+++ b/pkg/csvcopy/options.go
@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"strings"
 	"time"
+
+	"github.com/timescale/timescaledb-parallel-copy/pkg/batch"
 )
 
 type Option func(c *Copier) error
@@ -186,3 +188,17 @@ func WithSchemaName(schema string) Option {
 		return nil
 	}
 }
+
+// BatchErrorHandler is called when a batch fails to copy.
+// It receives the batch so the data can be inspected.
+// The error carries the failure reason.
+// If the handler itself returns an error, the workers stop and the copy fails.
+type BatchErrorHandler func(batch batch.Batch, err error) error
+
+// WithBatchErrorHandler sets the handler invoked when a batch fails to copy
+func WithBatchErrorHandler(handler BatchErrorHandler) Option {
+	return func(c *Copier) error {
+		c.failHandler = handler
+		return nil
+	}
+}
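The diff wires the built-in handlers together in main.go for the CLI, but the same composition is available to any library consumer. Below is a minimal sketch of a caller using the new options; it relies only on APIs introduced in this diff, while the connection string, table name, and file paths are hypothetical:

```go
package main

import (
	"context"
	"log"
	"os"

	"github.com/timescale/timescaledb-parallel-copy/pkg/batch"
	"github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy"
)

func main() {
	// Hypothetical connection string and input file.
	connStr := "postgres://postgres:postgres@localhost:5432/test-db"

	// Custom handler: record where the failed batch sits in the source file,
	// then return nil so the copy continues with the remaining batches.
	// Handlers receive the batch already rewound, so its Data can be read.
	handler := func(b batch.Batch, reason error) error {
		log.Printf("batch starting at row %d (offset %d, %d bytes) failed: %v",
			b.Location.StartRow, b.Location.ByteOffset, b.Location.ByteLen, reason)
		return nil
	}

	copier, err := csvcopy.NewCopier(connStr, "metrics",
		csvcopy.WithColumns("device_id,label,value"),
		csvcopy.WithBatchSize(5000),
		// Save each failed batch to failed/<start-row>.csv, then delegate
		// to the custom handler above.
		csvcopy.WithBatchErrorHandler(csvcopy.BatchHandlerSaveToFile("failed", handler)),
	)
	if err != nil {
		log.Fatal(err)
	}

	f, err := os.Open("metrics.csv") // hypothetical input file
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	result, err := copier.Copy(context.Background(), f)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("copied %d rows in %s", result.RowsRead, result.Duration)
}
```

This mirrors how the CLI chains handlers: each wrapper does its own work (saving, logging) and then calls `next`, so the final skip-or-fail decision stays in the innermost handler.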
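The `ByteOffset`/`ByteLen` pair in `batch.Location` is what makes later reprocessing possible, as the new field comments suggest. A short sketch of re-reading a failed batch from the source file with `ReadAt` (the file path and retry policy are hypothetical, not part of this diff):

```go
package main

import (
	"fmt"
	"os"

	"github.com/timescale/timescaledb-parallel-copy/pkg/batch"
)

// reReadBatch loads the raw bytes of a failed batch back from the original
// file, using the Location captured when the batch was scanned.
func reReadBatch(path string, loc batch.Location) ([]byte, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	buf := make([]byte, loc.ByteLen)
	if _, err := f.ReadAt(buf, int64(loc.ByteOffset)); err != nil {
		return nil, fmt.Errorf("re-reading batch at offset %d: %w", loc.ByteOffset, err)
	}
	return buf, nil
}
```

A handler could persist each failed batch's `Location` (for example alongside the .csv files written by `BatchHandlerSaveToFile`) and feed the recovered bytes into a later retry pass.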