diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index cc39ffb566c..b7ef7226afd 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -5191,8 +5191,13 @@ def _generate_tables_from_shards(shards: list["Dataset"], batch_size: int):
 
     @staticmethod
     def _generate_tables_from_cache_file(filename: str):
-        for batch_idx, batch in enumerate(_memory_mapped_record_batch_reader_from_file(filename)):
-            yield batch_idx, pa.Table.from_batches([batch])
+        reader, mmap_stream = _memory_mapped_record_batch_reader_from_file(filename)
+        try:
+            for batch_idx, batch in enumerate(reader):
+                yield batch_idx, pa.Table.from_batches([batch])
+        finally:
+            reader.close()
+            mmap_stream.close()
 
     def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset":
         """Get an [`datasets.IterableDataset`] from a map-style [`datasets.Dataset`].
diff --git a/src/datasets/table.py b/src/datasets/table.py
index b9b808ac9a0..bc077e717d1 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -35,6 +35,8 @@ def _in_memory_arrow_table_from_file(filename: str) -> pa.Table:
     in_memory_stream = pa.input_stream(filename)
     opened_stream = pa.ipc.open_stream(in_memory_stream)
     pa_table = opened_stream.read_all()
+    opened_stream.close()
+    in_memory_stream.close()
     return pa_table
 
 
@@ -42,12 +44,35 @@ def _in_memory_arrow_table_from_buffer(buffer: pa.Buffer) -> pa.Table:
     stream = pa.BufferReader(buffer)
     opened_stream = pa.ipc.open_stream(stream)
     table = opened_stream.read_all()
+    opened_stream.close()
+    stream.close()
     return table
 
 
-def _memory_mapped_record_batch_reader_from_file(filename: str) -> pa.RecordBatchStreamReader:
+def _memory_mapped_record_batch_reader_from_file(
+    filename: str,
+) -> tuple[pa.RecordBatchStreamReader, pa.MemoryMappedFile]:
+    """
+    Creates a memory-mapped record batch reader from a file.
+
+    This function opens a file as a memory-mapped stream and initializes
+    a RecordBatchStreamReader for reading Arrow record batches from the stream.
+
+    Note: Both the returned RecordBatchStreamReader and MemoryMappedFile
+    must be explicitly closed after use to release resources.
+
+    Args:
+        filename (str): The path to the file to be memory-mapped.
+
+    Returns:
+        tuple[pa.RecordBatchStreamReader, pa.MemoryMappedFile]:
+            A tuple containing:
+            - A RecordBatchStreamReader for reading Arrow record batches.
+            - A MemoryMappedFile object representing the memory-mapped file.
+
+    """
     memory_mapped_stream = pa.memory_map(filename)
-    return pa.ipc.open_stream(memory_mapped_stream)
+    return pa.ipc.open_stream(memory_mapped_stream), memory_mapped_stream
 
 
 def read_schema_from_file(filename: str) -> pa.Schema:
@@ -61,8 +86,10 @@ def read_schema_from_file(filename: str) -> pa.Schema:
 
 
 def _memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
-    opened_stream = _memory_mapped_record_batch_reader_from_file(filename)
+    opened_stream, memory_mapped_stream = _memory_mapped_record_batch_reader_from_file(filename)
     pa_table = opened_stream.read_all()
+    opened_stream.close()
+    memory_mapped_stream.close()
     return pa_table
 
 