Skip to content

Commit 36084a0

Browse files
committed
Reuse Table constructor to idenfity non-ffi tables when using udtf
1 parent 81b46cb commit 36084a0

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

python/tests/test_catalog.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pyarrow as pa
2121
import pyarrow.dataset as ds
2222
import pytest
23-
from datafusion import SessionContext, Table
23+
from datafusion import SessionContext, Table, udtf
2424

2525

2626
# Note we take in `database` as a variable even though we don't use
@@ -232,3 +232,19 @@ def test_in_end_to_end_python_providers(ctx: SessionContext):
232232
assert len(batches) == 1
233233
assert batches[0].column(0) == pa.array([1, 2, 3])
234234
assert batches[0].column(1) == pa.array([4, 5, 6])
235+
236+
237+
def test_register_python_function_as_udtf(ctx: SessionContext):
238+
basic_table = Table(ctx.sql("SELECT 3 AS value"))
239+
240+
@udtf("my_table_function")
241+
def my_table_function_udtf() -> Table:
242+
return basic_table
243+
244+
ctx.register_udtf(my_table_function_udtf)
245+
246+
result = ctx.sql("SELECT * FROM my_table_function()").collect()
247+
assert len(result) == 1
248+
assert len(result[0]) == 1
249+
assert len(result[0][0]) == 1
250+
assert result[0][0][0].as_py() == 3

src/udtf.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@ use std::sync::Arc;
2121
use crate::errors::{py_datafusion_err, to_datafusion_err};
2222
use crate::expr::PyExpr;
2323
use crate::table::PyTable;
24-
use crate::utils::{table_provider_from_pycapsule, validate_pycapsule};
24+
use crate::utils::validate_pycapsule;
2525
use datafusion::catalog::{TableFunctionImpl, TableProvider};
2626
use datafusion::error::Result as DataFusionResult;
2727
use datafusion::logical_expr::Expr;
2828
use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
29-
use pyo3::exceptions::PyNotImplementedError;
3029
use pyo3::types::{PyCapsule, PyTuple};
3130

3231
/// Represents a user defined table function
@@ -98,11 +97,7 @@ fn call_python_table_function(
9897
let provider_obj = func.call1(py, py_args)?;
9998
let provider = provider_obj.bind(py);
10099

101-
table_provider_from_pycapsule(provider)?.ok_or_else(|| {
102-
PyNotImplementedError::new_err(
103-
"__datafusion_table_provider__ does not exist on Table Provider object.",
104-
)
105-
})
100+
Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider)?.table)
106101
})
107102
.map_err(to_datafusion_err)
108103
}

0 commit comments

Comments
 (0)