diff --git a/spec/pg/connection_spec.cr b/spec/pg/connection_spec.cr index d0e27769..7bb45052 100644 --- a/spec/pg/connection_spec.cr +++ b/spec/pg/connection_spec.cr @@ -142,40 +142,32 @@ describe PG, "#clear_time_zone_cache" do end end -describe PG, "COPY out" do - it "supports COPY TO STDOUT data transfer" do +describe PG, "COPY" do + it "properly handles partial reads and consumes data on early close" do with_connection do |db| - io = db.copy_out "COPY (SELECT 'text', NULL, 1) TO STDOUT" - io.gets_to_end.should eq "text\t\\N\t1\n" + io = db.exec_copy "COPY (VALUES (1), (333)) TO STDOUT" + io.read_char.should eq '1' + io.read_char.should eq '\n' + io.read_char.should eq '3' + io.read_char.should eq '3' io.close db.scalar("select 1").should eq(1) end end - it "propely consumes data on early close" do + if "survives a COPY FROM STDIN and COPY TO STDOUT round-trip" with_connection do |db| - io = db.copy_out "COPY (SELECT * FROM generate_series(1, 100) x) TO STDOUT" - io.gets.should eq "1" - io.gets.should eq "2" - io.gets.should eq "3" - io.close - db.scalar("select 1").should eq(1) - end - end + data = "123\tdata\n\\N\t\\N\n" + db.exec("CREATE TEMPORARY TABLE IF NOT EXISTS copy_test (a int, b text)") - it "properly handles partial reads" do - with_connection do |db| - io = db.copy_out "COPY (VALUES (1), (333)) TO STDOUT" - io.read_char.should eq '1' - io.read_char.should eq '\n' - io.read_char.should eq '3' - io.read_char.should eq '3' - io.read_char.should eq '3' - io.read_char.should eq '\n' - io.read_char.should eq nil - io.read_char.should eq nil - io.close - db.scalar("select 1").should eq(1) + wr = db.exec_copy "COPY copy_test FROM STDIN" + wr << data + wr.close + + rd = db.exec_copy "COPY copy_test TO STDOUT" + rd.gets_to_end.should eq data + + db.exec("DROP TABLE copy_test") end end end diff --git a/src/pg/connection.cr b/src/pg/connection.cr index 476633ba..1f22202a 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -37,16 +37,20 @@ module PG nil end - # Execute a "COPY .. TO STDOUT" query and return an IO object to read from. - # The IO *must* be closed before using the connection again. + # Execute a "COPY" query and return an IO object to read from or write to, + # depending on the query. # # ``` - # io = conn.copy_out "COPY table TO STDOUT" - # data = io.gets_to_end - # io.close + # data = conn.exec_copy("COPY table TO STDOUT").gets_to_end # ``` - def copy_out(query : String) : CopyOut - CopyOut.new connection, query + # + # ``` + # writer = conn.exec_copy "COPY table FROM STDIN") + # writer << data + # writer.close + # ``` + def exec_copy(query : String) : CopyResult + CopyResult.new connection, query end # Set the callback block for notices and errors. diff --git a/src/pg/copy_out.cr b/src/pg/copy_out.cr deleted file mode 100644 index 21b36f19..00000000 --- a/src/pg/copy_out.cr +++ /dev/null @@ -1,52 +0,0 @@ -class PG::CopyOut < IO - getter? closed : Bool - - def initialize(@connection : PQ::Connection, query : String) - @connection.send_query_message query - @connection.expect_frame PQ::Frame::CopyOutResponse - - @frame_size = 0 # Remaining bytes in the current frame - @end = false - @closed = false - end - - def read(slice : Bytes) : Int32 - check_open - - return 0 if slice.empty? - return 0 if @end - - if @frame_size == 0 - if @connection.read_next_copy_start - @frame_size = @connection.read_i32 - 4 - else - @end = true - return 0 - end - end - - max_bytes = slice.size > @frame_size ? @frame_size : slice.size - bytes = @connection.read_direct(slice[0..max_bytes - 1]) - @frame_size -= bytes - bytes - end - - def write(slice : Bytes) : NoReturn - raise "Can't write to PG::CopyOut" - end - - def close : Nil - return if @closed - @closed = true - - unless @end - while @connection.read_next_copy_start - size = @connection.read_i32 - 4 - @connection.skip_bytes size - end - end - - @connection.expect_frame PQ::Frame::CommandComplete - @connection.expect_frame PQ::Frame::ReadyForQuery - end -end diff --git a/src/pg/copy_result.cr b/src/pg/copy_result.cr new file mode 100644 index 00000000..eb6fe153 --- /dev/null +++ b/src/pg/copy_result.cr @@ -0,0 +1,90 @@ +# IO object obtained through PG::Connection.exec_copy. +class PG::CopyResult < IO + getter? closed : Bool + + def initialize(@connection : PQ::Connection, query : String) + @connection.send_query_message query + response = @connection.expect_frame PQ::Frame::CopyOutResponse | PQ::Frame::CopyInResponse + + @reading = response.is_a? PQ::Frame::CopyOutResponse + @frame_size = 0 + @end = false + @closed = false + end + + private def read_final(done) + return if @end + @end = true + + unless done + @connection.skip_bytes @frame_size if @frame_size > 0 + + while @connection.read_next_copy_start + size = @connection.read_i32 - 4 + @connection.skip_bytes size + end + end + + @connection.expect_frame PQ::Frame::CommandComplete + @connection.expect_frame PQ::Frame::ReadyForQuery + end + + # Returns the number of remaining bytes in the current row. + # Returns 0 the are no more rows to be read. + # This can be used to allocate the precise amount of memory to read a complete row. + # + # ``` + # size = io.remaining_row_size + # if size != 0 + # row = Bytes.new(size) + # io.read(row) + # # Process the row. + # end + # ``` + def remaining_row_size : Int32 + raise "Can't read from a write-only PG::CopyResult" unless @reading + check_open + + return 0 if @end + + if @frame_size == 0 + if @connection.read_next_copy_start + @frame_size = @connection.read_i32 - 4 + else + read_final true + return 0 + end + end + + @frame_size + end + + def read(slice : Bytes) : Int32 + return 0 if slice.empty? + + remaining = remaining_row_size + return 0 if remaining == 0 + + max_bytes = slice.size > remaining ? remaining : slice.size + bytes = @connection.read_direct(slice[0..max_bytes - 1]) + @frame_size -= bytes + bytes + end + + def write(slice : Bytes) : Nil + raise "Can't write to a read-only PG::CopyResult" if @reading + @connection.send_copy_data_message slice + end + + def close : Nil + return if @closed + if @reading + read_final false + else + @connection.send_copy_done_message + @connection.expect_frame PQ::Frame::CommandComplete + @connection.expect_frame PQ::Frame::ReadyForQuery + end + @closed = true + end +end diff --git a/src/pq/connection.cr b/src/pq/connection.cr index 0e666772..8d6ad888 100644 --- a/src/pq/connection.cr +++ b/src/pq/connection.cr @@ -550,5 +550,17 @@ module PQ write_chr 'X' write_i32 4 end + + def send_copy_data_message(slice) + write_chr 'd' + write_i32 4 + slice.size + soc.write slice + end + + def send_copy_done_message + write_chr 'c' + write_i32 4 + soc.flush + end end end diff --git a/src/pq/frame.cr b/src/pq/frame.cr index 7df9173c..4b02b899 100644 --- a/src/pq/frame.cr +++ b/src/pq/frame.cr @@ -22,6 +22,7 @@ module PQ when 'K' then BackendKeyData when 'R' then Authentication when 'c' then CopyDone + when 'G' then CopyInResponse when 'H' then CopyOutResponse else nil end @@ -251,6 +252,9 @@ module PQ struct CopyDone < Frame end + struct CopyInResponse < Frame + end + struct CopyOutResponse < Frame end end