Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor into importable pkg #77

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Continuous integration for the Go module: lint, format check, and tests.
name: Go

on:
  push:
    branches: [master, main]
  pull_request:
    branches: ["**"]

jobs:
  build:
    name: Build
    runs-on: ubuntu-latest

    steps:
      - name: Set up Go
        uses: actions/setup-go@v2
        with:
          go-version: "1.22"

      - name: Check out code
        uses: actions/checkout@v2

      - name: Install dependencies
        run: go mod download

      - name: golangci-lint
        uses: golangci/golangci-lint-action@v4
        with:
          version: v1.54

      - name: Format
        # `gofmt -w` rewrites files on the runner but always exits 0, so it
        # can never fail the job. List offending files instead and fail if
        # any are found.
        run: |
          test -z "$(gofmt -l -s .)"

      - name: Test
        run: go test -v ./...
202 changes: 41 additions & 161 deletions cmd/timescaledb-parallel-copy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,11 @@ import (
"fmt"
"io"
"log"
"net"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"

_ "github.com/jackc/pgx/v4/stdlib"

"github.com/timescale/timescaledb-parallel-copy/internal/batch"
"github.com/timescale/timescaledb-parallel-copy/internal/db"
"github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy"
)

const (
Expand All @@ -29,7 +22,6 @@ const (
// Flag vars
var (
postgresConnect string
overrides []db.Overrideable
schemaName string
tableName string
truncate bool
Expand All @@ -52,12 +44,11 @@ var (
verbose bool
showVersion bool

rowCount int64
dbName string
)

// Parse args
func init() {
var dbName string
flag.StringVar(&postgresConnect, "connection", "host=localhost user=postgres sslmode=disable", "PostgreSQL connection url")
flag.StringVar(&dbName, "db-name", "", "Database where the destination table exists")
flag.StringVar(&tableName, "table", "test_table", "Destination table for insertions")
Expand All @@ -83,14 +74,12 @@ func init() {
flag.BoolVar(&showVersion, "version", false, "Show the version of this tool")

flag.Parse()

if dbName != "" {
overrides = append(overrides, db.OverrideDBName(dbName))
}
}

func getFullTableName() string {
return fmt.Sprintf(`"%s"."%s"`, schemaName, tableName)
// csvCopierLogger forwards informational messages from the csvcopy
// package to the standard library logger.
type csvCopierLogger struct{}

// Infof logs a printf-style message via log.Printf.
func (csvCopierLogger) Infof(msg string, args ...interface{}) {
	log.Printf(msg, args...)
}

func main() {
Expand All @@ -99,37 +88,42 @@ func main() {
os.Exit(0)
}

if len(quoteCharacter) > 1 {
fmt.Println("ERROR: provided --quote must be a single-byte character")
os.Exit(1)
}

if len(escapeCharacter) > 1 {
fmt.Println("ERROR: provided --escape must be a single-byte character")
os.Exit(1)
copier, err := csvcopy.NewCopier(
postgresConnect,
dbName,
schemaName,
tableName,
copyOptions,
splitCharacter,
quoteCharacter,
escapeCharacter,
columns,
skipHeader,
headerLinesCnt,
workers,
limit,
batchSize,
logBatches,
reportingPeriod,
verbose,
csvcopy.WithLogger(&csvCopierLogger{}),
)
if err != nil {
log.Fatal(err)
}

if truncate { // Remove existing data from the table
dbx, err := db.Connect(postgresConnect, overrides...)
err = copier.Truncate()
if err != nil {
panic(err)
}
_, err = dbx.Exec(fmt.Sprintf("TRUNCATE %s", getFullTableName()))
if err != nil {
panic(err)
}

err = dbx.Close()
if err != nil {
panic(err)
log.Printf("failed to trunctate table: %s", err)
}
}

var reader io.Reader
if len(fromFile) > 0 {
file, err := os.Open(fromFile)
if err != nil {
log.Fatal(err)
log.Fatalf("failed to open file: %s", err)
}
defer file.Close()

Expand All @@ -138,133 +132,19 @@ func main() {
reader = os.Stdin
}

if headerLinesCnt <= 0 {
fmt.Printf("WARNING: provided --header-line-count (%d) must be greater than 0\n", headerLinesCnt)
os.Exit(1)
}

var skip int
if skipHeader {
skip = headerLinesCnt

if verbose {
fmt.Printf("Skipping the first %d lines of the input.\n", headerLinesCnt)
}
}

var wg sync.WaitGroup
batchChan := make(chan net.Buffers, workers*2)

// Generate COPY workers
for i := 0; i < workers; i++ {
wg.Add(1)
go processBatches(&wg, batchChan)
}

// Reporting thread
if reportingPeriod > (0 * time.Second) {
go report()
}

opts := batch.Options{
Size: batchSize,
Skip: skip,
Limit: limit,
}

if quoteCharacter != "" {
// we already verified the length above
opts.Quote = quoteCharacter[0]
}
if escapeCharacter != "" {
// we already verified the length above
opts.Escape = escapeCharacter[0]
}

start := time.Now()
if err := batch.Scan(reader, batchChan, opts); err != nil {
log.Fatalf("Error reading input: %s", err.Error())
result, err := copier.Copy(reader)
if err != nil {
log.Fatal("failed to copy CSV:", err)
}

close(batchChan)
wg.Wait()
end := time.Now()
took := end.Sub(start)

rowsRead := atomic.LoadInt64(&rowCount)
rowRate := float64(rowsRead) / float64(took.Seconds())

res := fmt.Sprintf("COPY %d", rowsRead)
res := fmt.Sprintf("COPY %d", result.RowsRead)
if verbose {
res += fmt.Sprintf(", took %v with %d worker(s) (mean rate %f/sec)", took, workers, rowRate)
res += fmt.Sprintf(
", took %v with %d worker(s) (mean rate %f/sec)",
result.Duration,
workers,
result.RowRate,
)
}
fmt.Println(res)
}

// report periodically prints the write rate in number of rows per second
func report() {
start := time.Now()
prevTime := start
prevRowCount := int64(0)

for now := range time.NewTicker(reportingPeriod).C {
rCount := atomic.LoadInt64(&rowCount)

took := now.Sub(prevTime)
rowrate := float64(rCount-prevRowCount) / float64(took.Seconds())
overallRowrate := float64(rCount) / float64(now.Sub(start).Seconds())
totalTook := now.Sub(start)

fmt.Printf("at %v, row rate %0.2f/sec (period), row rate %0.2f/sec (overall), %E total rows\n", totalTook-(totalTook%time.Second), rowrate, overallRowrate, float64(rCount))

prevRowCount = rCount
prevTime = now
}
}

// processBatches reads batches from channel c and copies them to the target
// server while tracking stats on the write.
func processBatches(wg *sync.WaitGroup, c chan net.Buffers) {
dbx, err := db.Connect(postgresConnect, overrides...)
if err != nil {
panic(err)
}
defer dbx.Close()

delimStr := "'" + splitCharacter + "'"
if splitCharacter == tabCharStr {
delimStr = "E" + delimStr
}

var quotes string
if quoteCharacter != "" {
quotes = fmt.Sprintf("QUOTE '%s'",
strings.ReplaceAll(quoteCharacter, "'", "''"))
}
if escapeCharacter != "" {
quotes = fmt.Sprintf("%s ESCAPE '%s'",
quotes, strings.ReplaceAll(escapeCharacter, "'", "''"))
}

var copyCmd string
if columns != "" {
copyCmd = fmt.Sprintf("COPY %s(%s) FROM STDIN WITH DELIMITER %s %s %s", getFullTableName(), columns, delimStr, quotes, copyOptions)
} else {
copyCmd = fmt.Sprintf("COPY %s FROM STDIN WITH DELIMITER %s %s %s", getFullTableName(), delimStr, quotes, copyOptions)
}

for batch := range c {
start := time.Now()
rows, err := db.CopyFromLines(dbx, &batch, copyCmd)
if err != nil {
panic(err)
}
atomic.AddInt64(&rowCount, rows)

if logBatches {
took := time.Since(start)
fmt.Printf("[BATCH] took %v, batch size %d, row rate %f/sec\n", took, batchSize, float64(batchSize)/float64(took.Seconds()))
}
}
wg.Done()
}
Loading
Loading