Skip to content

Commit 38bb25a

Browse files
committed
Initial implementation of unified table suggestion
1 parent 4429614 commit 38bb25a

File tree

14 files changed

+108
-391
lines changed

14 files changed

+108
-391
lines changed

examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_ffi_aggregate_register():
6060
def test_ffi_aggregate_call_directly():
6161
ctx = setup_context_with_table()
6262
my_udaf = udaf(MySumUDF())
63-
63+
6464
result = (
6565
ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect()
6666
)

python/datafusion/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from . import functions, object_store, substrait, unparser
3535

3636
# The following imports are okay to remain as opaque to the user.
37-
from ._internal import EXPECTED_PROVIDER_MSG, Config
37+
from ._internal import Config
3838
from .catalog import Catalog, Database, Table
3939
from .col import col, column
4040
from .common import DFSchema
@@ -65,7 +65,6 @@
6565
__version__ = importlib_metadata.version(__name__)
6666

6767
__all__ = [
68-
"EXPECTED_PROVIDER_MSG",
6968
"Accumulator",
7069
"AggregateUDF",
7170
"Catalog",

python/datafusion/catalog.py

Lines changed: 12 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,15 @@
1919

2020
from __future__ import annotations
2121

22-
import warnings
2322
from abc import ABC, abstractmethod
2423
from typing import TYPE_CHECKING, Any, Protocol
2524

2625
import datafusion._internal as df_internal
27-
from datafusion._internal import EXPECTED_PROVIDER_MSG
28-
from datafusion.utils import _normalize_table_provider
2926

3027
if TYPE_CHECKING:
3128
import pyarrow as pa
3229

30+
from datafusion import DataFrame
3331
from datafusion.context import TableProviderExportable
3432

3533
try:
@@ -139,8 +137,7 @@ def register_table(
139137
Objects implementing ``__datafusion_table_provider__`` are also supported
140138
and treated as table provider instances.
141139
"""
142-
provider = _normalize_table_provider(table)
143-
return self._raw_schema.register_table(name, provider)
140+
return self._raw_schema.register_table(name, table)
144141

145142
def deregister_table(self, name: str) -> None:
146143
"""Deregister a table provider from this schema."""
@@ -152,101 +149,37 @@ class Database(Schema):
152149
"""See `Schema`."""
153150

154151

155-
_InternalRawTable = df_internal.catalog.RawTable
156-
_InternalTableProvider = df_internal.TableProvider
157-
158-
# Keep in sync with ``datafusion._internal.TableProvider.from_view``.
159-
_FROM_VIEW_WARN_STACKLEVEL = 2
160-
161-
162152
class Table:
163153
"""DataFusion table or table provider wrapper."""
164154

165-
__slots__ = ("_table",)
155+
__slots__ = ("_inner",)
166156

167157
def __init__(
168158
self,
169-
table: _InternalRawTable | _InternalTableProvider | Table,
159+
table: DataFrame | TableProviderExportable | pa.dataset.Dataset,
170160
) -> None:
171161
"""Wrap a low level table or table provider."""
172-
if isinstance(table, Table):
173-
table = table.table
174-
175-
if not isinstance(table, (_InternalRawTable, _InternalTableProvider)):
176-
raise TypeError(EXPECTED_PROVIDER_MSG)
177-
178-
self._table = table
179-
180-
def __getattribute__(self, name: str) -> Any:
181-
"""Restrict provider-specific helpers to compatible tables."""
182-
if name == "__datafusion_table_provider__":
183-
table = object.__getattribute__(self, "_table")
184-
if not hasattr(table, "__datafusion_table_provider__"):
185-
raise AttributeError(name)
186-
return object.__getattribute__(self, name)
162+
self._inner = df_internal.catalog.RawTable(table)
187163

188164
def __repr__(self) -> str:
189165
"""Print a string representation of the table."""
190-
return repr(self._table)
191-
192-
@property
193-
def table(self) -> _InternalRawTable | _InternalTableProvider:
194-
"""Return the wrapped low level table object."""
195-
return self._table
166+
return repr(self._inner)
196167

197-
@classmethod
198-
def from_dataset(cls, dataset: pa.dataset.Dataset) -> Table:
168+
@deprecated("Use Table() constructor instead.")
169+
@staticmethod
170+
def from_dataset(dataset: pa.dataset.Dataset) -> Table:
199171
"""Turn a :mod:`pyarrow.dataset` ``Dataset`` into a :class:`Table`."""
200-
return cls(_InternalRawTable.from_dataset(dataset))
201-
202-
@classmethod
203-
def from_capsule(cls, capsule: Any) -> Table:
204-
"""Create a :class:`Table` from a PyCapsule exported provider."""
205-
provider = _InternalTableProvider.from_capsule(capsule)
206-
return cls(provider)
207-
208-
@classmethod
209-
def from_dataframe(cls, df: Any) -> Table:
210-
"""Create a :class:`Table` from tabular data."""
211-
from datafusion.dataframe import DataFrame as DataFrameWrapper
212-
213-
dataframe = df if isinstance(df, DataFrameWrapper) else DataFrameWrapper(df)
214-
return dataframe.into_view()
215-
216-
@classmethod
217-
def from_view(cls, df: Any) -> Table:
218-
"""Deprecated helper for constructing tables from views."""
219-
from datafusion.dataframe import DataFrame as DataFrameWrapper
220-
221-
if isinstance(df, DataFrameWrapper):
222-
df = df.df
223-
224-
provider = _InternalTableProvider.from_view(df)
225-
warnings.warn(
226-
"Table.from_view is deprecated; use DataFrame.into_view or "
227-
"Table.from_dataframe instead.",
228-
category=DeprecationWarning,
229-
stacklevel=_FROM_VIEW_WARN_STACKLEVEL,
230-
)
231-
return cls(provider)
172+
return Table(dataset)
232173

233174
@property
234175
def schema(self) -> pa.Schema:
235176
"""Returns the schema associated with this table."""
236-
return self._table.schema
177+
return self._inner.schema
237178

238179
@property
239180
def kind(self) -> str:
240181
"""Returns the kind of table."""
241-
return self._table.kind
242-
243-
def __datafusion_table_provider__(self) -> Any:
244-
"""Expose the wrapped provider for FFI integrations."""
245-
exporter = getattr(self._table, "__datafusion_table_provider__", None)
246-
if exporter is None:
247-
msg = "Underlying object does not export __datafusion_table_provider__()"
248-
raise AttributeError(msg)
249-
return exporter()
182+
return self._inner.kind
250183

251184

252185
class CatalogProvider(ABC):

python/datafusion/context.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from datafusion.dataframe import DataFrame
3434
from datafusion.expr import sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
36-
from datafusion.utils import _normalize_table_provider
3736

3837
from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
3938
from ._internal import SessionConfig as SessionConfigInternal
@@ -770,8 +769,7 @@ def register_table(
770769
table: DataFusion :class:`Table` or any object implementing
771770
``__datafusion_table_provider__`` to add to the session context.
772771
"""
773-
provider = _normalize_table_provider(table)
774-
self.ctx.register_table(name, provider)
772+
self.ctx.register_table(name, table)
775773

776774
def deregister_table(self, name: str) -> None:
777775
"""Remove a table from the session."""
@@ -1197,7 +1195,7 @@ def read_table(self, table: Table) -> DataFrame:
11971195
:py:class:`~datafusion.catalog.ListingTable`, create a
11981196
:py:class:`~datafusion.dataframe.DataFrame`.
11991197
"""
1200-
return DataFrame(self.ctx.read_table(table.table))
1198+
return DataFrame(self.ctx.read_table(table._inner))
12011199

12021200
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
12031201
"""Execute the ``plan`` and return the results."""

python/datafusion/utils.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
from __future__ import annotations
2020

2121
from importlib import import_module, util
22-
from typing import TYPE_CHECKING, Any
23-
24-
from datafusion._internal import EXPECTED_PROVIDER_MSG
22+
from typing import Any
2523

2624
_PYARROW_DATASET_TYPES: tuple[type[Any], ...]
2725
_dataset_spec = util.find_spec("pyarrow.dataset")
@@ -37,38 +35,3 @@
3735
if isinstance(value, type) and issubclass(value, dataset_base):
3836
dataset_types.add(value)
3937
_PYARROW_DATASET_TYPES = tuple(dataset_types)
40-
41-
if TYPE_CHECKING: # pragma: no cover - imported for typing only
42-
from datafusion.catalog import Table
43-
from datafusion.context import TableProviderExportable
44-
45-
46-
def _normalize_table_provider(
47-
table: Table | TableProviderExportable | Any,
48-
) -> Any:
49-
"""Return the underlying provider for supported table inputs.
50-
51-
Args:
52-
table: A :class:`~datafusion.Table`, object exporting a DataFusion table
53-
provider via ``__datafusion_table_provider__``, or compatible
54-
:mod:`pyarrow.dataset` implementation.
55-
56-
Returns:
57-
The object expected by the Rust bindings for table registration.
58-
59-
Raises:
60-
TypeError: If ``table`` is not a supported table provider input.
61-
"""
62-
from datafusion.catalog import Table as _Table
63-
64-
if isinstance(table, _Table):
65-
return table.table
66-
67-
if _PYARROW_DATASET_TYPES and isinstance(table, _PYARROW_DATASET_TYPES):
68-
return table
69-
70-
provider_factory = getattr(table, "__datafusion_table_provider__", None)
71-
if callable(provider_factory):
72-
return table
73-
74-
raise TypeError(EXPECTED_PROVIDER_MSG)

python/tests/test_catalog.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pyarrow as pa
2121
import pyarrow.dataset as ds
2222
import pytest
23-
from datafusion import EXPECTED_PROVIDER_MSG, SessionContext, Table
23+
from datafusion import SessionContext, Table
2424

2525

2626
# Note we take in `database` as a variable even though we don't use
@@ -186,16 +186,6 @@ def test_schema_register_table_with_pyarrow_dataset(ctx: SessionContext):
186186
schema.deregister_table(table_name)
187187

188188

189-
def test_schema_register_table_with_dataframe_errors(ctx: SessionContext):
190-
schema = ctx.catalog().schema()
191-
df = ctx.from_pydict({"a": [1]})
192-
193-
with pytest.raises(Exception) as exc_info:
194-
schema.register_table("bad", df)
195-
196-
assert str(exc_info.value) == EXPECTED_PROVIDER_MSG
197-
198-
199189
def test_in_end_to_end_python_providers(ctx: SessionContext):
200190
"""Test registering all python providers and running a query against them."""
201191

python/tests/test_context.py

Lines changed: 2 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import pyarrow.dataset as ds
2323
import pytest
2424
from datafusion import (
25-
EXPECTED_PROVIDER_MSG,
2625
DataFrame,
2726
RuntimeEnvBuilder,
2827
SessionConfig,
@@ -341,20 +340,9 @@ def test_register_table_from_dataframe_into_view(ctx):
341340
assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]
342341

343342

344-
def test_table_from_capsule(ctx):
345-
df = ctx.from_pydict({"a": [1, 2]})
346-
table = df.into_view()
347-
capsule = table.__datafusion_table_provider__()
348-
table2 = Table.from_capsule(capsule)
349-
assert isinstance(table2, Table)
350-
ctx.register_table("capsule_tbl", table2)
351-
result = ctx.sql("SELECT * FROM capsule_tbl").collect()
352-
assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]
353-
354-
355343
def test_table_from_dataframe(ctx):
356344
df = ctx.from_pydict({"a": [1, 2]})
357-
table = Table.from_dataframe(df)
345+
table = Table(df)
358346
assert isinstance(table, Table)
359347
ctx.register_table("from_dataframe_tbl", table)
360348
result = ctx.sql("SELECT * FROM from_dataframe_tbl").collect()
@@ -363,42 +351,13 @@ def test_table_from_dataframe(ctx):
363351

364352
def test_table_from_dataframe_internal(ctx):
365353
df = ctx.from_pydict({"a": [1, 2]})
366-
table = Table.from_dataframe(df.df)
354+
table = Table(df.df)
367355
assert isinstance(table, Table)
368356
ctx.register_table("from_internal_dataframe_tbl", table)
369357
result = ctx.sql("SELECT * FROM from_internal_dataframe_tbl").collect()
370358
assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]
371359

372360

373-
def test_register_table_capsule_direct(ctx):
374-
df = ctx.from_pydict({"a": [1, 2]})
375-
provider = df.into_view()
376-
377-
class CapsuleProvider:
378-
def __init__(self, inner):
379-
self._inner = inner
380-
381-
def __datafusion_table_provider__(self):
382-
return self._inner.__datafusion_table_provider__()
383-
384-
ctx.register_table("capsule_direct_tbl", CapsuleProvider(provider))
385-
result = ctx.sql("SELECT * FROM capsule_direct_tbl").collect()
386-
assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]
387-
388-
389-
def test_table_from_capsule_invalid():
390-
with pytest.raises(RuntimeError):
391-
Table.from_capsule(object())
392-
393-
394-
def test_register_table_with_dataframe_errors(ctx):
395-
df = ctx.from_pydict({"a": [1]})
396-
with pytest.raises(TypeError) as exc_info:
397-
ctx.register_table("bad", df)
398-
399-
assert str(exc_info.value) == EXPECTED_PROVIDER_MSG
400-
401-
402361
def test_register_dataset(ctx):
403362
# create a RecordBatch and register it as a pyarrow.dataset.Dataset
404363
batch = pa.RecordBatch.from_arrays(

0 commit comments

Comments
 (0)