Skip to content

Commit 2a379e1

Browse files
alamb, adriangb
authored and committed
Improve async_udf example and docs (apache#16846)
* Improve async_udf example and docs
* tweak
* Remove random monospace async and version note
* Fix explain plan diff by hard coding parallelism
* rename arguments, use as_string_view_array
* request --> reqwest
1 parent c679225 commit 2a379e1

File tree

4 files changed

+178
-192
lines changed

4 files changed

+178
-192
lines changed

datafusion-examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ cargo run --example dataframe
5050
- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
5151
- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF)
5252
- [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files
53+
- [`async_udf.rs`](examples/async_udf.rs): Define and invoke an asynchronous User Defined Scalar Function (UDF)
5354
- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control)
5455
- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog
5556
- [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization

datafusion-examples/examples/async_udf.rs

Lines changed: 137 additions & 165 deletions
Original file line number | Diff line number | Diff line change
@@ -15,104 +15,104 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{ArrayIter, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray};
19-
use arrow::compute::kernels::cmp::eq;
18+
//! This example shows how to create and use "Async UDFs" in DataFusion.
19+
//!
20+
//! Async UDFs allow you to perform asynchronous operations, such as
21+
//! making network requests. This can be used for tasks like fetching
22+
//! data from an external API such as an LLM service or an external database.
23+
24+
use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray};
2025
use arrow_schema::{DataType, Field, Schema};
2126
use async_trait::async_trait;
27+
use datafusion::assert_batches_eq;
28+
use datafusion::common::cast::as_string_view_array;
2229
use datafusion::common::error::Result;
23-
use datafusion::common::types::{logical_int64, logical_string};
30+
use datafusion::common::not_impl_err;
2431
use datafusion::common::utils::take_function_args;
25-
use datafusion::common::{internal_err, not_impl_err};
2632
use datafusion::config::ConfigOptions;
33+
use datafusion::execution::SessionStateBuilder;
2734
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
2835
use datafusion::logical_expr::{
29-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
30-
TypeSignatureClass, Volatility,
36+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
3137
};
32-
use datafusion::logical_expr_common::signature::Coercion;
33-
use datafusion::physical_expr_common::datum::apply_cmp;
34-
use datafusion::prelude::SessionContext;
35-
use log::trace;
38+
use datafusion::prelude::{SessionConfig, SessionContext};
3639
use std::any::Any;
3740
use std::sync::Arc;
3841

3942
#[tokio::main]
4043
async fn main() -> Result<()> {
41-
let ctx: SessionContext = SessionContext::new();
42-
43-
let async_upper = AsyncUpper::new();
44-
let udf = AsyncScalarUDF::new(Arc::new(async_upper));
45-
ctx.register_udf(udf.into_scalar_udf());
46-
let async_equal = AsyncEqual::new();
44+
// Use a hard coded parallelism level of 4 so the explain plan
45+
// is consistent across machines.
46+
let config = SessionConfig::new().with_target_partitions(4);
47+
let ctx =
48+
SessionContext::from(SessionStateBuilder::new().with_config(config).build());
49+
50+
// Similarly to regular UDFs, you create an AsyncScalarUDF by implementing
51+
// `AsyncScalarUDFImpl` and creating an instance of `AsyncScalarUDF`.
52+
let async_equal = AskLLM::new();
4753
let udf = AsyncScalarUDF::new(Arc::new(async_equal));
54+
55+
// Async UDFs are registered with the SessionContext, using the same
56+
// `register_udf` method as regular UDFs.
4857
ctx.register_udf(udf.into_scalar_udf());
58+
59+
// Create a table named 'animal' with some sample data
4960
ctx.register_batch("animal", animal()?)?;
5061

51-
// use Async UDF in the projection
52-
// +---------------+----------------------------------------------------------------------------------------+
53-
// | plan_type | plan |
54-
// +---------------+----------------------------------------------------------------------------------------+
55-
// | logical_plan | Projection: async_equal(a.id, Int64(1)) |
56-
// | | SubqueryAlias: a |
57-
// | | TableScan: animal projection=[id] |
58-
// | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))] |
59-
// | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
60-
// | | CoalesceBatchesExec: target_batch_size=8192 |
61-
// | | DataSourceExec: partitions=1, partition_sizes=[1] |
62-
// | | |
63-
// +---------------+----------------------------------------------------------------------------------------+
64-
ctx.sql("explain select async_equal(a.id, 1) from animal a")
62+
// You can use the async UDF as normal in SQL queries
63+
//
64+
// Note: Async UDFs can currently be used in the select list and filter conditions.
65+
let results = ctx
66+
.sql("select * from animal a where ask_llm(a.name, 'Is this animal furry?')")
6567
.await?
66-
.show()
68+
.collect()
6769
.await?;
6870

69-
// +----------------------------+
70-
// | async_equal(a.id,Int64(1)) |
71-
// +----------------------------+
72-
// | true |
73-
// | false |
74-
// | false |
75-
// | false |
76-
// | false |
77-
// +----------------------------+
78-
ctx.sql("select async_equal(a.id, 1) from animal a")
71+
assert_batches_eq!(
72+
[
73+
"+----+------+",
74+
"| id | name |",
75+
"+----+------+",
76+
"| 1 | cat |",
77+
"| 2 | dog |",
78+
"+----+------+",
79+
],
80+
&results
81+
);
82+
83+
// While the interface is the same for both normal and async UDFs, you can
84+
// use `EXPLAIN` output to see that the async UDF uses a special
85+
// `AsyncFuncExec` node in the physical plan:
86+
let results = ctx
87+
.sql("explain select * from animal a where ask_llm(a.name, 'Is this animal furry?')")
7988
.await?
80-
.show()
89+
.collect()
8190
.await?;
8291

83-
// use Async UDF in the filter
84-
// +---------------+--------------------------------------------------------------------------------------------+
85-
// | plan_type | plan |
86-
// +---------------+--------------------------------------------------------------------------------------------+
87-
// | logical_plan | SubqueryAlias: a |
88-
// | | Filter: async_equal(animal.id, Int64(1)) |
89-
// | | TableScan: animal projection=[id, name] |
90-
// | physical_plan | CoalesceBatchesExec: target_batch_size=8192 |
91-
// | | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |
92-
// | | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |
93-
// | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
94-
// | | CoalesceBatchesExec: target_batch_size=8192 |
95-
// | | DataSourceExec: partitions=1, partition_sizes=[1] |
96-
// | | |
97-
// +---------------+--------------------------------------------------------------------------------------------+
98-
ctx.sql("explain select * from animal a where async_equal(a.id, 1)")
99-
.await?
100-
.show()
101-
.await?;
102-
103-
// +----+------+
104-
// | id | name |
105-
// +----+------+
106-
// | 1 | cat |
107-
// +----+------+
108-
ctx.sql("select * from animal a where async_equal(a.id, 1)")
109-
.await?
110-
.show()
111-
.await?;
92+
assert_batches_eq!(
93+
[
94+
"+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
95+
"| plan_type | plan |",
96+
"+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
97+
"| logical_plan | SubqueryAlias: a |",
98+
"| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\")) |",
99+
"| | TableScan: animal projection=[id, name] |",
100+
"| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |",
101+
"| | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |",
102+
"| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |",
103+
"| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |",
104+
"| | CoalesceBatchesExec: target_batch_size=8192 |",
105+
"| | DataSourceExec: partitions=1, partition_sizes=[1] |",
106+
"| | |",
107+
"+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
108+
],
109+
&results
110+
);
112111

113112
Ok(())
114113
}
115114

115+
/// Returns a sample `RecordBatch` representing an "animal" table with two columns: `id` and `name`.
116116
fn animal() -> Result<RecordBatch> {
117117
let schema = Arc::new(Schema::new(vec![
118118
Field::new("id", DataType::Int64, false),
@@ -127,118 +127,45 @@ fn animal() -> Result<RecordBatch> {
127127
Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?)
128128
}
129129

130+
/// An async UDF that simulates asking a large language model (LLM) service a
131+
/// question based on the content of two columns. The UDF will return a boolean
132+
/// indicating whether the LLM thinks the first argument matches the question in
133+
/// the second argument.
134+
///
135+
/// Since this is a simplified example, it does not call an LLM service, but
136+
/// could be extended to do so in a real-world scenario.
130137
#[derive(Debug)]
131-
pub struct AsyncUpper {
132-
signature: Signature,
133-
}
134-
135-
impl Default for AsyncUpper {
136-
fn default() -> Self {
137-
Self::new()
138-
}
139-
}
140-
141-
impl AsyncUpper {
142-
pub fn new() -> Self {
143-
Self {
144-
signature: Signature::new(
145-
TypeSignature::Coercible(vec![Coercion::Exact {
146-
desired_type: TypeSignatureClass::Native(logical_string()),
147-
}]),
148-
Volatility::Volatile,
149-
),
150-
}
151-
}
152-
}
153-
154-
#[async_trait]
155-
impl ScalarUDFImpl for AsyncUpper {
156-
fn as_any(&self) -> &dyn Any {
157-
self
158-
}
159-
160-
fn name(&self) -> &str {
161-
"async_upper"
162-
}
163-
164-
fn signature(&self) -> &Signature {
165-
&self.signature
166-
}
167-
168-
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
169-
Ok(DataType::Utf8)
170-
}
171-
172-
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
173-
not_impl_err!("AsyncUpper can only be called from async contexts")
174-
}
175-
}
176-
177-
#[async_trait]
178-
impl AsyncScalarUDFImpl for AsyncUpper {
179-
fn ideal_batch_size(&self) -> Option<usize> {
180-
Some(10)
181-
}
182-
183-
async fn invoke_async_with_args(
184-
&self,
185-
args: ScalarFunctionArgs,
186-
_option: &ConfigOptions,
187-
) -> Result<ArrayRef> {
188-
trace!("Invoking async_upper with args: {:?}", args);
189-
let value = &args.args[0];
190-
let result = match value {
191-
ColumnarValue::Array(array) => {
192-
let string_array = array.as_string::<i32>();
193-
let iter = ArrayIter::new(string_array);
194-
let result = iter
195-
.map(|string| string.map(|s| s.to_uppercase()))
196-
.collect::<StringArray>();
197-
Arc::new(result) as ArrayRef
198-
}
199-
_ => return internal_err!("Expected a string argument, got {:?}", value),
200-
};
201-
Ok(result)
202-
}
203-
}
204-
205-
#[derive(Debug)]
206-
struct AsyncEqual {
138+
struct AskLLM {
207139
signature: Signature,
208140
}
209141

210-
impl Default for AsyncEqual {
142+
impl Default for AskLLM {
211143
fn default() -> Self {
212144
Self::new()
213145
}
214146
}
215147

216-
impl AsyncEqual {
148+
impl AskLLM {
217149
pub fn new() -> Self {
218150
Self {
219-
signature: Signature::new(
220-
TypeSignature::Coercible(vec![
221-
Coercion::Exact {
222-
desired_type: TypeSignatureClass::Native(logical_int64()),
223-
},
224-
Coercion::Exact {
225-
desired_type: TypeSignatureClass::Native(logical_int64()),
226-
},
227-
]),
151+
signature: Signature::exact(
152+
vec![DataType::Utf8View, DataType::Utf8View],
228153
Volatility::Volatile,
229154
),
230155
}
231156
}
232157
}
233158

234-
#[async_trait]
235-
impl ScalarUDFImpl for AsyncEqual {
159+
/// All async UDFs implement the `ScalarUDFImpl` trait, which provides the basic
160+
/// information for the function, such as its name, signature, and return type.
161+
/// Note: `#[async_trait]` is not needed on this impl because none of the methods implemented here are `async`.
162+
impl ScalarUDFImpl for AskLLM {
236163
fn as_any(&self) -> &dyn Any {
237164
self
238165
}
239166

240167
fn name(&self) -> &str {
241-
"async_equal"
168+
"ask_llm"
242169
}
243170

244171
fn signature(&self) -> &Signature {
@@ -249,19 +176,64 @@ impl ScalarUDFImpl for AsyncEqual {
249176
Ok(DataType::Boolean)
250177
}
251178

179+
/// Since this is an async UDF, the `invoke_with_args` method will not be
180+
/// called directly.
252181
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
253-
not_impl_err!("AsyncEqual can only be called from async contexts")
182+
not_impl_err!("AskLLM can only be called from async contexts")
254183
}
255184
}
256185

186+
/// In addition to [`ScalarUDFImpl`], we also need to implement the
187+
/// [`AsyncScalarUDFImpl`] trait.
257188
#[async_trait]
258-
impl AsyncScalarUDFImpl for AsyncEqual {
189+
impl AsyncScalarUDFImpl for AskLLM {
190+
/// The `invoke_async_with_args` method is similar to `invoke_with_args`,
191+
/// but it returns a `Future` that resolves to the result.
192+
///
193+
/// Since this signature is `async`, it can do any `async` operations, such
194+
/// as network requests. This method is run on the same tokio `Runtime` that
195+
/// is processing the query, so you may wish to make actual network requests
196+
/// on a different `Runtime`, as explained in the `thread_pools.rs` example
197+
/// in this directory.
259198
async fn invoke_async_with_args(
260199
&self,
261200
args: ScalarFunctionArgs,
262201
_option: &ConfigOptions,
263202
) -> Result<ArrayRef> {
264-
let [arg1, arg2] = take_function_args(self.name(), &args.args)?;
265-
apply_cmp(arg1, arg2, eq)?.to_array(args.number_rows)
203+
// In a real UDF you would likely want to special case constant
204+
// arguments to improve performance, but this example converts the
205+
// arguments to arrays for simplicity.
206+
let args = ColumnarValue::values_to_arrays(&args.args)?;
207+
let [content_column, question_column] = take_function_args(self.name(), args)?;
208+
209+
// In a real function, you would use a library such as `reqwest` here to
210+
// make an async HTTP request. Credentials and other configurations can
211+
// be supplied via the `ConfigOptions` parameter.
212+
213+
// In this example, we will simulate the LLM response by comparing the two
214+
// input arguments using some static strings.
215+
let content_column = as_string_view_array(&content_column)?;
216+
let question_column = as_string_view_array(&question_column)?;
217+
218+
let result_array: BooleanArray = content_column
219+
.iter()
220+
.zip(question_column.iter())
221+
.map(|(a, b)| {
222+
// If either value is null, return None
223+
let a = a?;
224+
let b = b?;
225+
// Simulate an LLM response by checking the arguments against some
226+
// hardcoded conditions.
227+
if a.contains("cat") && b.contains("furry")
228+
|| a.contains("dog") && b.contains("furry")
229+
{
230+
Some(true)
231+
} else {
232+
Some(false)
233+
}
234+
})
235+
.collect();
236+
237+
Ok(Arc::new(result_array))
266238
}
267239
}

datafusion/core/src/execution/context/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -226,7 +226,7 @@ where
226226
/// # use datafusion::execution::SessionStateBuilder;
227227
/// # use datafusion_execution::runtime_env::RuntimeEnvBuilder;
228228
/// // Configure a 4k batch size
229-
/// let config = SessionConfig::new() .with_batch_size(4 * 1024);
229+
/// let config = SessionConfig::new().with_batch_size(4 * 1024);
230230
///
231231
/// // configure a memory limit of 1GB with 20% slop
232232
/// let runtime_env = RuntimeEnvBuilder::new()

0 commit comments

Comments
 (0)