Skip to content

Commit

Permalink
Refactor into importable pkg
Browse files Browse the repository at this point in the history
  • Loading branch information
alejandrodnm committed Apr 22, 2024
1 parent fbb7292 commit 421fba6
Show file tree
Hide file tree
Showing 7 changed files with 3,221 additions and 165 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Go

on:
push:
branches: [master, main]
pull_request:
branches: ["**"]

jobs:
build:
name: Build
runs-on: ubuntu-latest

steps:
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: "1.22"

- name: Check out code
uses: actions/checkout@v2

- name: Install dependencies
run: go mod download

- name: golangci-lint
uses: golangci/golangci-lint-action@v4
with:
version: v1.54

- name: Format
run: |
gofmt -l -s -w .
- name: Test
run: go test -v ./...
202 changes: 41 additions & 161 deletions cmd/timescaledb-parallel-copy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,11 @@ import (
"fmt"
"io"
"log"
"net"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"

_ "github.com/jackc/pgx/v4/stdlib"

"github.com/timescale/timescaledb-parallel-copy/internal/batch"
"github.com/timescale/timescaledb-parallel-copy/internal/db"
"github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy"
)

const (
Expand All @@ -29,7 +22,6 @@ const (
// Flag vars
var (
postgresConnect string
overrides []db.Overrideable
schemaName string
tableName string
truncate bool
Expand All @@ -52,12 +44,11 @@ var (
verbose bool
showVersion bool

rowCount int64
dbName string
)

// Parse args
func init() {
var dbName string
flag.StringVar(&postgresConnect, "connection", "host=localhost user=postgres sslmode=disable", "PostgreSQL connection url")
flag.StringVar(&dbName, "db-name", "", "Database where the destination table exists")
flag.StringVar(&tableName, "table", "test_table", "Destination table for insertions")
Expand All @@ -83,14 +74,12 @@ func init() {
flag.BoolVar(&showVersion, "version", false, "Show the version of this tool")

flag.Parse()

if dbName != "" {
overrides = append(overrides, db.OverrideDBName(dbName))
}
}

func getFullTableName() string {
return fmt.Sprintf(`"%s"."%s"`, schemaName, tableName)
type csvCopierLogger struct{}

func (l csvCopierLogger) Infof(msg string, args ...interface{}) {
log.Printf(msg, args...)
}

func main() {
Expand All @@ -99,37 +88,42 @@ func main() {
os.Exit(0)
}

if len(quoteCharacter) > 1 {
fmt.Println("ERROR: provided --quote must be a single-byte character")
os.Exit(1)
}

if len(escapeCharacter) > 1 {
fmt.Println("ERROR: provided --escape must be a single-byte character")
os.Exit(1)
copier, err := csvcopy.NewCopier(
postgresConnect,
dbName,
schemaName,
tableName,
copyOptions,
splitCharacter,
quoteCharacter,
escapeCharacter,
columns,
skipHeader,
headerLinesCnt,
workers,
limit,
batchSize,
logBatches,
reportingPeriod,
verbose,
csvcopy.WithLogger(&csvCopierLogger{}),
)
if err != nil {
log.Fatal(err)
}

if truncate { // Remove existing data from the table
dbx, err := db.Connect(postgresConnect, overrides...)
err = copier.Truncate()
if err != nil {
panic(err)
}
_, err = dbx.Exec(fmt.Sprintf("TRUNCATE %s", getFullTableName()))
if err != nil {
panic(err)
}

err = dbx.Close()
if err != nil {
panic(err)
log.Printf("failed to trunctate table: %s", err)
}
}

var reader io.Reader
if len(fromFile) > 0 {
file, err := os.Open(fromFile)
if err != nil {
log.Fatal(err)
log.Fatalf("failed to open file: %s", err)
}
defer file.Close()

Expand All @@ -138,133 +132,19 @@ func main() {
reader = os.Stdin
}

if headerLinesCnt <= 0 {
fmt.Printf("WARNING: provided --header-line-count (%d) must be greater than 0\n", headerLinesCnt)
os.Exit(1)
}

var skip int
if skipHeader {
skip = headerLinesCnt

if verbose {
fmt.Printf("Skipping the first %d lines of the input.\n", headerLinesCnt)
}
}

var wg sync.WaitGroup
batchChan := make(chan net.Buffers, workers*2)

// Generate COPY workers
for i := 0; i < workers; i++ {
wg.Add(1)
go processBatches(&wg, batchChan)
}

// Reporting thread
if reportingPeriod > (0 * time.Second) {
go report()
}

opts := batch.Options{
Size: batchSize,
Skip: skip,
Limit: limit,
}

if quoteCharacter != "" {
// we already verified the length above
opts.Quote = quoteCharacter[0]
}
if escapeCharacter != "" {
// we already verified the length above
opts.Escape = escapeCharacter[0]
}

start := time.Now()
if err := batch.Scan(reader, batchChan, opts); err != nil {
log.Fatalf("Error reading input: %s", err.Error())
result, err := copier.Copy(reader)
if err != nil {
log.Fatal("failed to copy CSV:", err)
}

close(batchChan)
wg.Wait()
end := time.Now()
took := end.Sub(start)

rowsRead := atomic.LoadInt64(&rowCount)
rowRate := float64(rowsRead) / float64(took.Seconds())

res := fmt.Sprintf("COPY %d", rowsRead)
res := fmt.Sprintf("COPY %d", result.RowsRead)
if verbose {
res += fmt.Sprintf(", took %v with %d worker(s) (mean rate %f/sec)", took, workers, rowRate)
res += fmt.Sprintf(
", took %v with %d worker(s) (mean rate %f/sec)",
result.Duration,
workers,
result.RowRate,
)
}
fmt.Println(res)
}

// report periodically prints the write rate in number of rows per second
func report() {
start := time.Now()
prevTime := start
prevRowCount := int64(0)

for now := range time.NewTicker(reportingPeriod).C {
rCount := atomic.LoadInt64(&rowCount)

took := now.Sub(prevTime)
rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
totalTook := now.Sub(start)

fmt.Printf("at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows\n", totalTook-(totalTook%time.Second), rowrate, overallRowrate, float64(rCount))

prevRowCount = rCount
prevTime = now
}
}

// processBatches reads batches from channel c and copies them to the target
// server while tracking stats on the write.
func processBatches(wg *sync.WaitGroup, c chan net.Buffers) {
dbx, err := db.Connect(postgresConnect, overrides...)
if err != nil {
panic(err)
}
defer dbx.Close()

delimStr := "'" + splitCharacter + "'"
if splitCharacter == tabCharStr {
delimStr = "E" + delimStr
}

var quotes string
if quoteCharacter != "" {
quotes = fmt.Sprintf("QUOTE '%s'",
strings.ReplaceAll(quoteCharacter, "'", "''"))
}
if escapeCharacter != "" {
quotes = fmt.Sprintf("%s ESCAPE '%s'",
quotes, strings.ReplaceAll(escapeCharacter, "'", "''"))
}

var copyCmd string
if columns != "" {
copyCmd = fmt.Sprintf("COPY %s(%s) FROM STDIN WITH DELIMITER %s %s %s", getFullTableName(), columns, delimStr, quotes, copyOptions)
} else {
copyCmd = fmt.Sprintf("COPY %s FROM STDIN WITH DELIMITER %s %s %s", getFullTableName(), delimStr, quotes, copyOptions)
}

for batch := range c {
start := time.Now()
rows, err := db.CopyFromLines(dbx, &batch, copyCmd)
if err != nil {
panic(err)
}
atomic.AddInt64(&rowCount, rows)

if logBatches {
took := time.Since(start)
fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, batchSize, float64(batchSize)/float64(took.Seconds()))
}
}
wg.Done()
}
Loading

0 comments on commit 421fba6

Please sign in to comment.