15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
- use arrow:: array:: { ArrayIter , ArrayRef , AsArray , Int64Array , RecordBatch , StringArray } ;
19
- use arrow:: compute:: kernels:: cmp:: eq;
18
+ //! This example shows how to create and use "Async UDFs" in DataFusion.
19
+ //!
20
+ //! Async UDFs allow you to perform asynchronous operations, such as
21
+ //! making network requests. This can be used for tasks like fetching
22
+ //! data from an external API such as a LLM service or an external database.
23
+
24
+ use arrow:: array:: { ArrayRef , BooleanArray , Int64Array , RecordBatch , StringArray } ;
20
25
use arrow_schema:: { DataType , Field , Schema } ;
21
26
use async_trait:: async_trait;
27
+ use datafusion:: assert_batches_eq;
28
+ use datafusion:: common:: cast:: as_string_view_array;
22
29
use datafusion:: common:: error:: Result ;
23
- use datafusion:: common:: types :: { logical_int64 , logical_string } ;
30
+ use datafusion:: common:: not_impl_err ;
24
31
use datafusion:: common:: utils:: take_function_args;
25
- use datafusion:: common:: { internal_err, not_impl_err} ;
26
32
use datafusion:: config:: ConfigOptions ;
33
+ use datafusion:: execution:: SessionStateBuilder ;
27
34
use datafusion:: logical_expr:: async_udf:: { AsyncScalarUDF , AsyncScalarUDFImpl } ;
28
35
use datafusion:: logical_expr:: {
29
- ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
30
- TypeSignatureClass , Volatility ,
36
+ ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
31
37
} ;
32
- use datafusion:: logical_expr_common:: signature:: Coercion ;
33
- use datafusion:: physical_expr_common:: datum:: apply_cmp;
34
- use datafusion:: prelude:: SessionContext ;
35
- use log:: trace;
38
+ use datafusion:: prelude:: { SessionConfig , SessionContext } ;
36
39
use std:: any:: Any ;
37
40
use std:: sync:: Arc ;
38
41
39
42
#[ tokio:: main]
40
43
async fn main ( ) -> Result < ( ) > {
41
- let ctx: SessionContext = SessionContext :: new ( ) ;
42
-
43
- let async_upper = AsyncUpper :: new ( ) ;
44
- let udf = AsyncScalarUDF :: new ( Arc :: new ( async_upper) ) ;
45
- ctx. register_udf ( udf. into_scalar_udf ( ) ) ;
46
- let async_equal = AsyncEqual :: new ( ) ;
44
+ // Use a hard coded parallelism level of 4 so the explain plan
45
+ // is consistent across machines.
46
+ let config = SessionConfig :: new ( ) . with_target_partitions ( 4 ) ;
47
+ let ctx =
48
+ SessionContext :: from ( SessionStateBuilder :: new ( ) . with_config ( config) . build ( ) ) ;
49
+
50
+ // Similarly to regular UDFs, you create an AsyncScalarUDF by implementing
51
+ // `AsyncScalarUDFImpl` and creating an instance of `AsyncScalarUDF`.
52
+ let async_equal = AskLLM :: new ( ) ;
47
53
let udf = AsyncScalarUDF :: new ( Arc :: new ( async_equal) ) ;
54
+
55
+ // Async UDFs are registered with the SessionContext, using the same
56
+ // `register_udf` method as regular UDFs.
48
57
ctx. register_udf ( udf. into_scalar_udf ( ) ) ;
58
+
59
+ // Create a table named 'animal' with some sample data
49
60
ctx. register_batch ( "animal" , animal ( ) ?) ?;
50
61
51
- // use Async UDF in the projection
52
- // +---------------+----------------------------------------------------------------------------------------+
53
- // | plan_type | plan |
54
- // +---------------+----------------------------------------------------------------------------------------+
55
- // | logical_plan | Projection: async_equal(a.id, Int64(1)) |
56
- // | | SubqueryAlias: a |
57
- // | | TableScan: animal projection=[id] |
58
- // | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))] |
59
- // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
60
- // | | CoalesceBatchesExec: target_batch_size=8192 |
61
- // | | DataSourceExec: partitions=1, partition_sizes=[1] |
62
- // | | |
63
- // +---------------+----------------------------------------------------------------------------------------+
64
- ctx. sql ( "explain select async_equal(a.id, 1) from animal a" )
62
+ // You can use the async UDF as normal in SQL queries
63
+ //
64
+ // Note: Async UDFs can currently be used in the select list and filter conditions.
65
+ let results = ctx
66
+ . sql ( "select * from animal a where ask_llm(a.name, 'Is this animal furry?')" )
65
67
. await ?
66
- . show ( )
68
+ . collect ( )
67
69
. await ?;
68
70
69
- // +----------------------------+
70
- // | async_equal(a.id,Int64(1)) |
71
- // +----------------------------+
72
- // | true |
73
- // | false |
74
- // | false |
75
- // | false |
76
- // | false |
77
- // +----------------------------+
78
- ctx. sql ( "select async_equal(a.id, 1) from animal a" )
71
+ assert_batches_eq ! (
72
+ [
73
+ "+----+------+" ,
74
+ "| id | name |" ,
75
+ "+----+------+" ,
76
+ "| 1 | cat |" ,
77
+ "| 2 | dog |" ,
78
+ "+----+------+" ,
79
+ ] ,
80
+ & results
81
+ ) ;
82
+
83
+ // While the interface is the same for both normal and async UDFs, you can
84
+ // use `EXPLAIN` output to see that the async UDF uses a special
85
+ // `AsyncFuncExec` node in the physical plan:
86
+ let results = ctx
87
+ . sql ( "explain select * from animal a where ask_llm(a.name, 'Is this animal furry?')" )
79
88
. await ?
80
- . show ( )
89
+ . collect ( )
81
90
. await ?;
82
91
83
- // use Async UDF in the filter
84
- // +---------------+--------------------------------------------------------------------------------------------+
85
- // | plan_type | plan |
86
- // +---------------+--------------------------------------------------------------------------------------------+
87
- // | logical_plan | SubqueryAlias: a |
88
- // | | Filter: async_equal(animal.id, Int64(1)) |
89
- // | | TableScan: animal projection=[id, name] |
90
- // | physical_plan | CoalesceBatchesExec: target_batch_size=8192 |
91
- // | | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |
92
- // | | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |
93
- // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
94
- // | | CoalesceBatchesExec: target_batch_size=8192 |
95
- // | | DataSourceExec: partitions=1, partition_sizes=[1] |
96
- // | | |
97
- // +---------------+--------------------------------------------------------------------------------------------+
98
- ctx. sql ( "explain select * from animal a where async_equal(a.id, 1)" )
99
- . await ?
100
- . show ( )
101
- . await ?;
102
-
103
- // +----+------+
104
- // | id | name |
105
- // +----+------+
106
- // | 1 | cat |
107
- // +----+------+
108
- ctx. sql ( "select * from animal a where async_equal(a.id, 1)" )
109
- . await ?
110
- . show ( )
111
- . await ?;
92
+ assert_batches_eq ! (
93
+ [
94
+ "+---------------+--------------------------------------------------------------------------------------------------------------------------------+" ,
95
+ "| plan_type | plan |" ,
96
+ "+---------------+--------------------------------------------------------------------------------------------------------------------------------+" ,
97
+ "| logical_plan | SubqueryAlias: a |" ,
98
+ "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\" Is this animal furry?\" )) |" ,
99
+ "| | TableScan: animal projection=[id, name] |" ,
100
+ "| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |" ,
101
+ "| | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |" ,
102
+ "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |" ,
103
+ "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |" ,
104
+ "| | CoalesceBatchesExec: target_batch_size=8192 |" ,
105
+ "| | DataSourceExec: partitions=1, partition_sizes=[1] |" ,
106
+ "| | |" ,
107
+ "+---------------+--------------------------------------------------------------------------------------------------------------------------------+" ,
108
+ ] ,
109
+ & results
110
+ ) ;
112
111
113
112
Ok ( ( ) )
114
113
}
115
114
115
+ /// Returns a sample `RecordBatch` representing an "animal" table with two columns:
116
116
fn animal ( ) -> Result < RecordBatch > {
117
117
let schema = Arc :: new ( Schema :: new ( vec ! [
118
118
Field :: new( "id" , DataType :: Int64 , false ) ,
@@ -127,118 +127,45 @@ fn animal() -> Result<RecordBatch> {
127
127
Ok ( RecordBatch :: try_new ( schema, vec ! [ id_array, name_array] ) ?)
128
128
}
129
129
130
+ /// An async UDF that simulates asking a large language model (LLM) service a
131
+ /// question based on the content of two columns. The UDF will return a boolean
132
+ /// indicating whether the LLM thinks the first argument matches the question in
133
+ /// the second argument.
134
+ ///
135
+ /// Since this is a simplified example, it does not call an LLM service, but
136
+ /// could be extended to do so in a real-world scenario.
130
137
#[ derive( Debug ) ]
131
- pub struct AsyncUpper {
132
- signature : Signature ,
133
- }
134
-
135
- impl Default for AsyncUpper {
136
- fn default ( ) -> Self {
137
- Self :: new ( )
138
- }
139
- }
140
-
141
- impl AsyncUpper {
142
- pub fn new ( ) -> Self {
143
- Self {
144
- signature : Signature :: new (
145
- TypeSignature :: Coercible ( vec ! [ Coercion :: Exact {
146
- desired_type: TypeSignatureClass :: Native ( logical_string( ) ) ,
147
- } ] ) ,
148
- Volatility :: Volatile ,
149
- ) ,
150
- }
151
- }
152
- }
153
-
154
- #[ async_trait]
155
- impl ScalarUDFImpl for AsyncUpper {
156
- fn as_any ( & self ) -> & dyn Any {
157
- self
158
- }
159
-
160
- fn name ( & self ) -> & str {
161
- "async_upper"
162
- }
163
-
164
- fn signature ( & self ) -> & Signature {
165
- & self . signature
166
- }
167
-
168
- fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
169
- Ok ( DataType :: Utf8 )
170
- }
171
-
172
- fn invoke_with_args ( & self , _args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
173
- not_impl_err ! ( "AsyncUpper can only be called from async contexts" )
174
- }
175
- }
176
-
177
- #[ async_trait]
178
- impl AsyncScalarUDFImpl for AsyncUpper {
179
- fn ideal_batch_size ( & self ) -> Option < usize > {
180
- Some ( 10 )
181
- }
182
-
183
- async fn invoke_async_with_args (
184
- & self ,
185
- args : ScalarFunctionArgs ,
186
- _option : & ConfigOptions ,
187
- ) -> Result < ArrayRef > {
188
- trace ! ( "Invoking async_upper with args: {:?}" , args) ;
189
- let value = & args. args [ 0 ] ;
190
- let result = match value {
191
- ColumnarValue :: Array ( array) => {
192
- let string_array = array. as_string :: < i32 > ( ) ;
193
- let iter = ArrayIter :: new ( string_array) ;
194
- let result = iter
195
- . map ( |string| string. map ( |s| s. to_uppercase ( ) ) )
196
- . collect :: < StringArray > ( ) ;
197
- Arc :: new ( result) as ArrayRef
198
- }
199
- _ => return internal_err ! ( "Expected a string argument, got {:?}" , value) ,
200
- } ;
201
- Ok ( result)
202
- }
203
- }
204
-
205
- #[ derive( Debug ) ]
206
- struct AsyncEqual {
138
+ struct AskLLM {
207
139
signature : Signature ,
208
140
}
209
141
210
- impl Default for AsyncEqual {
142
+ impl Default for AskLLM {
211
143
fn default ( ) -> Self {
212
144
Self :: new ( )
213
145
}
214
146
}
215
147
216
- impl AsyncEqual {
148
+ impl AskLLM {
217
149
pub fn new ( ) -> Self {
218
150
Self {
219
- signature : Signature :: new (
220
- TypeSignature :: Coercible ( vec ! [
221
- Coercion :: Exact {
222
- desired_type: TypeSignatureClass :: Native ( logical_int64( ) ) ,
223
- } ,
224
- Coercion :: Exact {
225
- desired_type: TypeSignatureClass :: Native ( logical_int64( ) ) ,
226
- } ,
227
- ] ) ,
151
+ signature : Signature :: exact (
152
+ vec ! [ DataType :: Utf8View , DataType :: Utf8View ] ,
228
153
Volatility :: Volatile ,
229
154
) ,
230
155
}
231
156
}
232
157
}
233
158
234
- #[ async_trait]
235
- impl ScalarUDFImpl for AsyncEqual {
159
+ /// All async UDFs implement the `ScalarUDFImpl` trait, which provides the basic
160
+ /// information for the function, such as its name, signature, and return type.
161
+ /// [async_trait]
162
+ impl ScalarUDFImpl for AskLLM {
236
163
fn as_any ( & self ) -> & dyn Any {
237
164
self
238
165
}
239
166
240
167
fn name ( & self ) -> & str {
241
- "async_equal "
168
+ "ask_llm "
242
169
}
243
170
244
171
fn signature ( & self ) -> & Signature {
@@ -249,19 +176,64 @@ impl ScalarUDFImpl for AsyncEqual {
249
176
Ok ( DataType :: Boolean )
250
177
}
251
178
179
+ /// Since this is an async UDF, the `invoke_with_args` method will not be
180
+ /// called directly.
252
181
fn invoke_with_args ( & self , _args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
253
- not_impl_err ! ( "AsyncEqual can only be called from async contexts" )
182
+ not_impl_err ! ( "AskLLM can only be called from async contexts" )
254
183
}
255
184
}
256
185
186
+ /// In addition to [`ScalarUDFImpl`], we also need to implement the
187
+ /// [`AsyncScalarUDFImpl`] trait.
257
188
#[ async_trait]
258
- impl AsyncScalarUDFImpl for AsyncEqual {
189
+ impl AsyncScalarUDFImpl for AskLLM {
190
+ /// The `invoke_async_with_args` method is similar to `invoke_with_args`,
191
+ /// but it returns a `Future` that resolves to the result.
192
+ ///
193
+ /// Since this signature is `async`, it can do any `async` operations, such
194
+ /// as network requests. This method is run on the same tokio `Runtime` that
195
+ /// is processing the query, so you may wish to make actual network requests
196
+ /// on a different `Runtime`, as explained in the `thread_pools.rs` example
197
+ /// in this directory.
259
198
async fn invoke_async_with_args (
260
199
& self ,
261
200
args : ScalarFunctionArgs ,
262
201
_option : & ConfigOptions ,
263
202
) -> Result < ArrayRef > {
264
- let [ arg1, arg2] = take_function_args ( self . name ( ) , & args. args ) ?;
265
- apply_cmp ( arg1, arg2, eq) ?. to_array ( args. number_rows )
203
+ // in a real UDF you would likely want to special case constant
204
+ // arguments to improve performance, but this example converts the
205
+ // arguments to arrays for simplicity.
206
+ let args = ColumnarValue :: values_to_arrays ( & args. args ) ?;
207
+ let [ content_column, question_column] = take_function_args ( self . name ( ) , args) ?;
208
+
209
+ // In a real function, you would use a library such as `reqwest` here to
210
+ // make an async HTTP request. Credentials and other configurations can
211
+ // be supplied via the `ConfigOptions` parameter.
212
+
213
+ // In this example, we will simulate the LLM response by comparing the two
214
+ // input arguments using some static strings
215
+ let content_column = as_string_view_array ( & content_column) ?;
216
+ let question_column = as_string_view_array ( & question_column) ?;
217
+
218
+ let result_array: BooleanArray = content_column
219
+ . iter ( )
220
+ . zip ( question_column. iter ( ) )
221
+ . map ( |( a, b) | {
222
+ // If either value is null, return None
223
+ let a = a?;
224
+ let b = b?;
225
+ // Simulate an LLM response by checking the arguments to some
226
+ // hardcoded conditions.
227
+ if a. contains ( "cat" ) && b. contains ( "furry" )
228
+ || a. contains ( "dog" ) && b. contains ( "furry" )
229
+ {
230
+ Some ( true )
231
+ } else {
232
+ Some ( false )
233
+ }
234
+ } )
235
+ . collect ( ) ;
236
+
237
+ Ok ( Arc :: new ( result_array) )
266
238
}
267
239
}
0 commit comments