File tree Expand file tree Collapse file tree 2 files changed +19
-8
lines changed Expand file tree Collapse file tree 2 files changed +19
-8
lines changed Original file line number Diff line number Diff line change 20
20
import pyarrow as pa
21
21
import pyarrow .dataset as ds
22
22
import pytest
23
- from datafusion import SessionContext , Table
23
+ from datafusion import SessionContext , Table , udtf
24
24
25
25
26
26
# Note we take in `database` as a variable even though we don't use
@@ -232,3 +232,19 @@ def test_in_end_to_end_python_providers(ctx: SessionContext):
232
232
assert len (batches ) == 1
233
233
assert batches [0 ].column (0 ) == pa .array ([1 , 2 , 3 ])
234
234
assert batches [0 ].column (1 ) == pa .array ([4 , 5 , 6 ])
235
+
236
+
237
+ def test_register_python_function_as_udtf (ctx : SessionContext ):
238
+ basic_table = Table (ctx .sql ("SELECT 3 AS value" ))
239
+
240
+ @udtf ("my_table_function" )
241
+ def my_table_function_udtf () -> Table :
242
+ return basic_table
243
+
244
+ ctx .register_udtf (my_table_function_udtf )
245
+
246
+ result = ctx .sql ("SELECT * FROM my_table_function()" ).collect ()
247
+ assert len (result ) == 1
248
+ assert len (result [0 ]) == 1
249
+ assert len (result [0 ][0 ]) == 1
250
+ assert result [0 ][0 ][0 ].as_py () == 3
Original file line number Diff line number Diff line change @@ -21,12 +21,11 @@ use std::sync::Arc;
21
21
use crate :: errors:: { py_datafusion_err, to_datafusion_err} ;
22
22
use crate :: expr:: PyExpr ;
23
23
use crate :: table:: PyTable ;
24
- use crate :: utils:: { table_provider_from_pycapsule , validate_pycapsule} ;
24
+ use crate :: utils:: validate_pycapsule;
25
25
use datafusion:: catalog:: { TableFunctionImpl , TableProvider } ;
26
26
use datafusion:: error:: Result as DataFusionResult ;
27
27
use datafusion:: logical_expr:: Expr ;
28
28
use datafusion_ffi:: udtf:: { FFI_TableFunction , ForeignTableFunction } ;
29
- use pyo3:: exceptions:: PyNotImplementedError ;
30
29
use pyo3:: types:: { PyCapsule , PyTuple } ;
31
30
32
31
/// Represents a user defined table function
@@ -98,11 +97,7 @@ fn call_python_table_function(
98
97
let provider_obj = func. call1 ( py, py_args) ?;
99
98
let provider = provider_obj. bind ( py) ;
100
99
101
- table_provider_from_pycapsule ( provider) ?. ok_or_else ( || {
102
- PyNotImplementedError :: new_err (
103
- "__datafusion_table_provider__ does not exist on Table Provider object." ,
104
- )
105
- } )
100
+ Ok :: < Arc < dyn TableProvider > , PyErr > ( PyTable :: new ( provider) ?. table )
106
101
} )
107
102
. map_err ( to_datafusion_err)
108
103
}
You can’t perform that action at this time.
0 commit comments