From 1362c83ed5377c84aa30f8a0f394a17d5d0d4a53 Mon Sep 17 00:00:00 2001 From: francis-du Date: Wed, 14 Sep 2022 01:17:38 +0800 Subject: [PATCH 1/2] feat: impl a new Config class --- datafusion/__init__.py | 5 +-- datafusion/tests/test_config.py | 41 +++++++++++++++++ src/config.rs | 80 +++++++++++++++++++++++++++++++++ src/lib.rs | 3 ++ 4 files changed, 126 insertions(+), 3 deletions(-) create mode 100644 datafusion/tests/test_config.py create mode 100644 src/config.rs diff --git a/datafusion/__init__.py b/datafusion/__init__.py index c02e038f9..b2e1028f2 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -23,22 +23,21 @@ except ImportError: import importlib_metadata - import pyarrow as pa from ._internal import ( AggregateUDF, + Config, DataFrame, SessionContext, Expression, ScalarUDF, ) - __version__ = importlib_metadata.version(__name__) - __all__ = [ + "Config", "DataFrame", "SessionContext", "Expression", diff --git a/datafusion/tests/test_config.py b/datafusion/tests/test_config.py new file mode 100644 index 000000000..1e0616111 --- /dev/null +++ b/datafusion/tests/test_config.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datafusion import Config +import pytest + + +@pytest.fixture +def config(): + return Config() + + +def test_get_then_set(config): + config_key = "datafusion.optimizer.filter_null_join_keys" + + assert config.get(config_key).as_py() is False + + config.set(config_key, True) + assert config.get(config_key).as_py() is True + + +def test_get_all(config): + config.get_all() + + +def test_get_invalid_config(config): + assert config.get("not.valid.key") is None diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 000000000..664b2598b --- /dev/null +++ b/src/config.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use pyo3::prelude::*; +use pyo3::types::*; + +use datafusion::config::ConfigOptions; +use datafusion_common::ScalarValue; + +#[pyclass(name = "Config", module = "datafusion", subclass)] +#[derive(Clone)] +pub(crate) struct PyConfig { + config: ConfigOptions, +} + +#[pymethods] +impl PyConfig { + #[new] + fn py_new() -> Self { + Self { + config: ConfigOptions::new(), + } + } + + /// Get configurations from environment variables + #[staticmethod] + pub fn from_env() -> Self { + Self { + config: ConfigOptions::from_env(), + } + } + + /// Get a configuration option + pub fn get(&mut self, key: &str, py: Python) -> PyResult { + Ok(self.config.get(key).into_py(py)) + } + + /// Set a configuration option + pub fn set(&mut self, key: &str, value: PyObject, py: Python) { + self.config.set(key, py_obj_to_scalar_value(py, value)) + } + + /// Get all configuration options + pub fn get_all(&mut self, py: Python) -> PyResult { + let dict = PyDict::new(py); + for (key, value) in self.config.options() { + dict.set_item(key, value.clone().into_py(py))?; + } + Ok(dict.into()) + } +} + +/// Convert a python object to a ScalarValue +fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> ScalarValue { + if let Ok(value) = obj.extract::(py) { + ScalarValue::Boolean(Some(value)) + } else if let Ok(value) = obj.extract::(py) { + ScalarValue::Int64(Some(value)) + } else if let Ok(value) = obj.extract::(py) { + ScalarValue::Float64(Some(value)) + } else if let Ok(value) = obj.extract::(py) { + ScalarValue::Utf8(Some(value)) + } else { + panic!("Unsupported value type") + } +} diff --git a/src/lib.rs b/src/lib.rs index 268dc0434..0423fb131 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,8 @@ use pyo3::prelude::*; #[allow(clippy::borrow_deref_ref)] pub mod catalog; #[allow(clippy::borrow_deref_ref)] +mod config; +#[allow(clippy::borrow_deref_ref)] mod context; #[allow(clippy::borrow_deref_ref)] mod dataframe; @@ -58,6 +60,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; // Register the functions as a submodule let funcs = PyModule::new(py, "functions")?; From 246f331ac896f30a87d84720c1334c5e5b01a58e Mon Sep 17 00:00:00 2001 From: francis-du Date: Sun, 18 Sep 2022 22:59:54 +0800 Subject: [PATCH 2/2] fix: add u64 support for config --- src/config.rs | 2 ++ src/expression.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/config.rs b/src/config.rs index 664b2598b..05a6a0c10 100644 --- a/src/config.rs +++ b/src/config.rs @@ -70,6 +70,8 @@ fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> ScalarValue { ScalarValue::Boolean(Some(value)) } else if let Ok(value) = obj.extract::(py) { ScalarValue::Int64(Some(value)) + } else if let Ok(value) = obj.extract::(py) { + ScalarValue::UInt64(Some(value)) } else if let Ok(value) = obj.extract::(py) { ScalarValue::Float64(Some(value)) } else if let Ok(value) = obj.extract::(py) { diff --git a/src/expression.rs b/src/expression.rs index b40199202..b52699d8d 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -19,7 +19,7 @@ use pyo3::{basic::CompareOp, prelude::*}; use std::convert::{From, Into}; use datafusion::arrow::datatypes::DataType; -use datafusion::logical_plan::{col, lit, Expr}; +use datafusion_expr::{col, lit, Expr}; use datafusion::scalar::ScalarValue;