diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 8289215de2e29..78c518661e14b 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1743,8 +1743,9 @@ cdef class RecordBatchStream(FlightDataStream):
     cdef:
         object data_source
         CIpcWriteOptions write_options
+        int64_t max_chunksize
 
-    def __init__(self, data_source, options=None):
+    def __init__(self, data_source, options=None, max_chunksize=None):
         """Create a RecordBatchStream from a data source.
 
         Parameters
@@ -1753,6 +1754,9 @@ cdef class RecordBatchStream(FlightDataStream):
             The data to stream to the client.
         options : pyarrow.ipc.IpcWriteOptions, optional
             Optional IPC options to control how to write the data.
+        max_chunksize : int, default None
+            Optional maximum number of rows for each chunk.
+            Only applicable if the data source is a Table.
         """
         if (not isinstance(data_source, RecordBatchReader) and
                 not isinstance(data_source, lib.Table)):
@@ -1760,6 +1764,10 @@ cdef class RecordBatchStream(FlightDataStream):
                             "but got: {}".format(type(data_source)))
         self.data_source = data_source
         self.write_options = _get_options(options).c_options
+        if max_chunksize is not None:
+            self.max_chunksize = max_chunksize
+        else:
+            self.max_chunksize = 0
 
     cdef CFlightDataStream* to_stream(self) except *:
         cdef:
@@ -1768,7 +1776,10 @@ cdef class RecordBatchStream(FlightDataStream):
             reader = (<RecordBatchReader> self.data_source).reader
         elif isinstance(self.data_source, lib.Table):
             table = (<Table> self.data_source).table
-            reader.reset(new TableBatchReader(deref(table)))
+            batch_reader = new TableBatchReader(deref(table))
+            if self.max_chunksize > 0:
+                batch_reader.set_chunksize(self.max_chunksize)
+            reader.reset(batch_reader)
         else:
             raise RuntimeError("Can't construct RecordBatchStream "
                                "from type {}".format(type(self.data_source)))
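
For context, a minimal usage sketch of the new `max_chunksize` parameter from a Flight server. The server class, dataset contents, and chunk size below are illustrative only and not part of this change:

```python
# Hypothetical example; the server name and table contents are made up.
import pyarrow as pa
import pyarrow.flight as flight


class ExampleServer(flight.FlightServerBase):
    def do_get(self, context, ticket):
        table = pa.table({"x": list(range(1_000_000))})
        # With a Table data source, each streamed record batch is capped
        # at max_chunksize rows; a RecordBatchReader source ignores it.
        return flight.RecordBatchStream(table, max_chunksize=65536)
```

When `max_chunksize` is None (the default), `set_chunksize()` is never called on the underlying `TableBatchReader`, so the stream behaves exactly as before this change.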