Skip to content

Commit

Permalink
Also add support for COPY in + fix close after partial read + remaini…
Browse files Browse the repository at this point in the history
…ng_row_size
  • Loading branch information
Yorhel committed Dec 21, 2023
1 parent 8d95238 commit ed26146
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 85 deletions.
44 changes: 18 additions & 26 deletions spec/pg/connection_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -142,40 +142,32 @@ describe PG, "#clear_time_zone_cache" do
end
end

describe PG, "COPY out" do
it "supports COPY TO STDOUT data transfer" do
describe PG, "COPY" do
it "properly handles partial reads and consumes data on early close" do
with_connection do |db|
io = db.copy_out "COPY (SELECT 'text', NULL, 1) TO STDOUT"
io.gets_to_end.should eq "text\t\\N\t1\n"
io = db.exec_copy "COPY (VALUES (1), (333)) TO STDOUT"
io.read_char.should eq '1'
io.read_char.should eq '\n'
io.read_char.should eq '3'
io.read_char.should eq '3'
io.close
db.scalar("select 1").should eq(1)
end
end

it "propely consumes data on early close" do
if "survives a COPY FROM STDIN and COPY TO STDOUT round-trip"
with_connection do |db|
io = db.copy_out "COPY (SELECT * FROM generate_series(1, 100) x) TO STDOUT"
io.gets.should eq "1"
io.gets.should eq "2"
io.gets.should eq "3"
io.close
db.scalar("select 1").should eq(1)
end
end
data = "123\tdata\n\\N\t\\N\n"
db.exec("CREATE TEMPORARY TABLE IF NOT EXISTS copy_test (a int, b text)")

it "properly handles partial reads" do
with_connection do |db|
io = db.copy_out "COPY (VALUES (1), (333)) TO STDOUT"
io.read_char.should eq '1'
io.read_char.should eq '\n'
io.read_char.should eq '3'
io.read_char.should eq '3'
io.read_char.should eq '3'
io.read_char.should eq '\n'
io.read_char.should eq nil
io.read_char.should eq nil
io.close
db.scalar("select 1").should eq(1)
wr = db.exec_copy "COPY copy_test FROM STDIN"
wr << data
wr.close

rd = db.exec_copy "COPY copy_test TO STDOUT"
rd.gets_to_end.should eq data

db.exec("DROP TABLE copy_test")
end
end
end
18 changes: 11 additions & 7 deletions src/pg/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,20 @@ module PG
nil
end

# Execute a "COPY .. TO STDOUT" query and return an IO object to read from.
# The IO *must* be closed before using the connection again.
# Execute a "COPY" query and return an IO object to read from or write to,
# depending on the query.
#
# ```
# io = conn.copy_out "COPY table TO STDOUT"
# data = io.gets_to_end
# io.close
# data = conn.exec_copy("COPY table TO STDOUT").gets_to_end
# ```
def copy_out(query : String) : CopyOut
CopyOut.new connection, query
#
# ```
# writer = conn.exec_copy "COPY table FROM STDIN")
# writer << data
# writer.close
# ```
def exec_copy(query : String) : CopyResult
CopyResult.new connection, query
end

# Set the callback block for notices and errors.
Expand Down
52 changes: 0 additions & 52 deletions src/pg/copy_out.cr

This file was deleted.

90 changes: 90 additions & 0 deletions src/pg/copy_result.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# IO object obtained through PG::Connection.exec_copy.
class PG::CopyResult < IO
getter? closed : Bool

def initialize(@connection : PQ::Connection, query : String)
@connection.send_query_message query
response = @connection.expect_frame PQ::Frame::CopyOutResponse | PQ::Frame::CopyInResponse

@reading = response.is_a? PQ::Frame::CopyOutResponse
@frame_size = 0
@end = false
@closed = false
end

private def read_final(done)
return if @end
@end = true

unless done
@connection.skip_bytes @frame_size if @frame_size > 0

while @connection.read_next_copy_start
size = @connection.read_i32 - 4
@connection.skip_bytes size
end
end

@connection.expect_frame PQ::Frame::CommandComplete
@connection.expect_frame PQ::Frame::ReadyForQuery
end

# Returns the number of remaining bytes in the current row.
# Returns 0 the are no more rows to be read.
# This can be used to allocate the precise amount of memory to read a complete row.
#
# ```
# size = io.remaining_row_size
# if size != 0
# row = Bytes.new(size)
# io.read(row)
# # Process the row.
# end
# ```
def remaining_row_size : Int32
raise "Can't read from a write-only PG::CopyResult" unless @reading
check_open

return 0 if @end

if @frame_size == 0
if @connection.read_next_copy_start
@frame_size = @connection.read_i32 - 4
else
read_final true
return 0
end
end

@frame_size
end

def read(slice : Bytes) : Int32
return 0 if slice.empty?

remaining = remaining_row_size
return 0 if remaining == 0

max_bytes = slice.size > remaining ? remaining : slice.size
bytes = @connection.read_direct(slice[0..max_bytes - 1])
@frame_size -= bytes
bytes
end

def write(slice : Bytes) : Nil
raise "Can't write to a read-only PG::CopyResult" if @reading
@connection.send_copy_data_message slice
end

def close : Nil
return if @closed
if @reading
read_final false
else
@connection.send_copy_done_message
@connection.expect_frame PQ::Frame::CommandComplete
@connection.expect_frame PQ::Frame::ReadyForQuery
end
@closed = true
end
end
12 changes: 12 additions & 0 deletions src/pq/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -550,5 +550,17 @@ module PQ
write_chr 'X'
write_i32 4
end

def send_copy_data_message(slice)
write_chr 'd'
write_i32 4 + slice.size
soc.write slice
end

def send_copy_done_message
write_chr 'c'
write_i32 4
soc.flush
end
end
end
4 changes: 4 additions & 0 deletions src/pq/frame.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ module PQ
when 'K' then BackendKeyData
when 'R' then Authentication
when 'c' then CopyDone
when 'G' then CopyInResponse
when 'H' then CopyOutResponse
else nil
end
Expand Down Expand Up @@ -251,6 +252,9 @@ module PQ
struct CopyDone < Frame
end

struct CopyInResponse < Frame
end

struct CopyOutResponse < Frame
end
end
Expand Down

0 comments on commit ed26146

Please sign in to comment.