Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,20 @@ class TestFlightClient : public ::testing::Test {
std::unique_ptr<FlightStreamReader> stream;
ASSERT_OK(client_->DoGet(ticket, &stream));

std::unique_ptr<FlightStreamReader> stream2;
ASSERT_OK(client_->DoGet(ticket, &stream2));
ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));

FlightStreamChunk chunk;
std::shared_ptr<RecordBatch> batch;
for (int i = 0; i < num_batches; ++i) {
ASSERT_OK(stream->Next(&chunk));
ASSERT_OK(reader->ReadNext(&batch));
ASSERT_NE(nullptr, chunk.data);
ASSERT_NE(nullptr, batch);
#if !defined(__MINGW32__)
ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
ASSERT_BATCHES_EQUAL(*expected_batches[i], *batch);
#else
// In MINGW32, the following code does not have the reproducibility at the LSB
// even when this is called twice with the same seed.
Expand All @@ -444,12 +452,15 @@ class TestFlightClient : public ::testing::Test {
// [&dist, &rng] { return static_cast<ValueType>(dist(rng)); });
// /* data[1] = 0x40852cdfe23d3976 or 0x40852cdfe23d3975 */
ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *chunk.data);
ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *batch);
#endif
}

// Stream exhausted
ASSERT_OK(stream->Next(&chunk));
ASSERT_OK(reader->ReadNext(&batch));
ASSERT_EQ(nullptr, chunk.data);
ASSERT_EQ(nullptr, batch);
}

protected:
Expand Down
36 changes: 36 additions & 0 deletions cpp/src/arrow/flight/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,42 @@ Status MetadataRecordBatchWriter::Begin(const std::shared_ptr<Schema>& schema) {
return Begin(schema, ipc::IpcWriteOptions::Defaults());
}

namespace {
class MetadataRecordBatchReaderAdapter : public RecordBatchReader {
public:
explicit MetadataRecordBatchReaderAdapter(
std::shared_ptr<Schema> schema, std::shared_ptr<MetadataRecordBatchReader> delegate)
: schema_(std::move(schema)), delegate_(std::move(delegate)) {}
std::shared_ptr<Schema> schema() const override { return schema_; }
Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
FlightStreamChunk next;
while (true) {
RETURN_NOT_OK(delegate_->Next(&next));
if (!next.data && !next.app_metadata) {
// EOS
*batch = nullptr;
return Status::OK();
} else if (next.data) {
*batch = std::move(next.data);
return Status::OK();
}
// Got metadata, but no data (which is valid) - read the next message
}
}

private:
std::shared_ptr<Schema> schema_;
std::shared_ptr<MetadataRecordBatchReader> delegate_;
};
}; // namespace

arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
std::shared_ptr<MetadataRecordBatchReader> reader) {
ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
return std::make_shared<MetadataRecordBatchReaderAdapter>(std::move(schema),
std::move(reader));
}

SimpleFlightListing::SimpleFlightListing(const std::vector<FlightInfo>& flights)
: position_(0), flights_(flights) {}

Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/flight/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ class ARROW_FLIGHT_EXPORT MetadataRecordBatchReader {
virtual Status ReadAll(std::shared_ptr<Table>* table);
};

/// \brief Convert a MetadataRecordBatchReader to a regular RecordBatchReader.
ARROW_FLIGHT_EXPORT
arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
std::shared_ptr<MetadataRecordBatchReader> reader);

/// \brief An interface to write IPC payloads with metadata.
class ARROW_FLIGHT_EXPORT MetadataRecordBatchWriter : public ipc::RecordBatchWriter {
public:
Expand Down
10 changes: 10 additions & 0 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,16 @@ cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin):

return chunk

def to_reader(self):
"""Convert this reader into a regular RecordBatchReader.

This may fail if the schema cannot be read from the remote end.
"""
cdef RecordBatchReader reader
reader = RecordBatchReader.__new__(RecordBatchReader)
reader.reader = GetResultValue(MakeRecordBatchReader(self.reader))
return reader


cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader):
"""The virtual base class for readers for Flight streams."""
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CStatus Next(CFlightStreamChunk* out)
CStatus ReadAll(shared_ptr[CTable]* table)

CResult[shared_ptr[CRecordBatchReader]] MakeRecordBatchReader\
" arrow::flight::MakeRecordBatchReader"(
shared_ptr[CMetadataRecordBatchReader])

cdef cppclass CMetadataRecordBatchWriter \
" arrow::flight::MetadataRecordBatchWriter"(CRecordBatchWriter):
CStatus Begin(shared_ptr[CSchema] schema,
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,10 @@ def test_flight_do_get_ints():
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)

# Also test via RecordBatchReader interface
data = client.do_get(flight.Ticket(b'ints')).to_reader().read_all()
assert data.equals(table)

with pytest.raises(flight.FlightServerError,
match="expected IpcWriteOptions, got <class 'int'>"):
with ConstantFlightServer(options=42) as server:
Expand Down