Skip to content

Exposing FFI to python #1137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/datafusion-ffi-example/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions examples/datafusion-ffi-example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"]
arrow = { version = "55.0.0" }
arrow-array = { version = "55.0.0" }
arrow-schema = { version = "55.0.0" }
async-trait = "0.1.88"

[build-dependencies]
pyo3-build-config = "0.23"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pyarrow as pa

from datafusion import SessionContext
from datafusion_ffi_example import MyCatalogProvider


def test_catalog_provider():
    """Exercise a Rust FFI catalog provider end to end from Python."""
    catalog_name = "my_catalog"
    schema_name = "my_schema"
    table_name = "my_table"
    column_names = ["units", "price"]

    ctx = SessionContext()
    ctx.register_catalog_provider(catalog_name, MyCatalogProvider())

    # The registered catalog must expose the schema and table baked into
    # the Rust-side provider.
    catalog = ctx.catalog(catalog_name)
    assert schema_name in catalog.names()

    database = catalog.database(schema_name)
    assert table_name in database.names()
    assert database.table(table_name).schema.names == column_names

    # The table is built from two record batches; collect() preserves them.
    batches = ctx.table(f"{catalog_name}.{schema_name}.{table_name}").collect()
    assert len(batches) == 2

    expected_units = [
        pa.array([10, 20, 30], type=pa.int32()),
        pa.array([5, 7], type=pa.int32()),
    ]
    expected_prices = [
        pa.array([1, 2, 5], type=pa.float64()),
        pa.array([1.5, 2.5], type=pa.float64()),
    ]
    assert [batch.column(0) for batch in batches] == expected_units
    assert [batch.column(1) for batch in batches] == expected_prices
181 changes: 181 additions & 0 deletions examples/datafusion-ffi-example/src/catalog_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{any::Any, fmt::Debug, sync::Arc};
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};

use arrow::datatypes::Schema;
use async_trait::async_trait;
use datafusion::{
catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
TableProvider,
},
common::exec_err,
datasource::MemTable,
error::{DataFusionError, Result},
};
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use pyo3::types::PyCapsule;

/// Build the example in-memory table: two record batches of
/// (units: Int32, price: Float64) rows backed by a `MemTable`.
pub fn my_table() -> Arc<dyn TableProvider + 'static> {
    use arrow::datatypes::{DataType, Field};
    use datafusion::common::record_batch;

    let fields = vec![
        Field::new("units", DataType::Int32, true),
        Field::new("price", DataType::Float64, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    let first_batch = record_batch!(
        ("units", Int32, vec![10, 20, 30]),
        ("price", Float64, vec![1.0, 2.0, 5.0])
    )
    .unwrap();
    let second_batch = record_batch!(
        ("units", Int32, vec![5, 7]),
        ("price", Float64, vec![1.5, 2.5])
    )
    .unwrap();

    // A single partition holding both batches, matching the schema above.
    Arc::new(MemTable::try_new(schema, vec![vec![first_batch, second_batch]]).unwrap())
}

/// Schema provider for the FFI example: wraps a `MemorySchemaProvider`
/// that `Default` pre-populates with a single table named "my_table".
#[derive(Debug)]
pub struct FixedSchemaProvider {
    // Delegation target; every SchemaProvider method forwards to this.
    inner: MemorySchemaProvider,
}

impl Default for FixedSchemaProvider {
    /// Create a provider whose only table is the example "my_table".
    fn default() -> Self {
        let provider = MemorySchemaProvider::new();

        // Registration into a fresh provider is expected to succeed;
        // the returned previous-table Option is intentionally discarded.
        let _ = provider
            .register_table("my_table".to_string(), my_table())
            .unwrap();

        Self { inner: provider }
    }
}

// Forwarding implementation: every SchemaProvider method delegates to the
// wrapped MemorySchemaProvider without modification.
#[async_trait]
impl SchemaProvider for FixedSchemaProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn table_names(&self) -> Vec<String> {
        self.inner.table_names()
    }

    // Async because the SchemaProvider trait defines table lookup as async.
    async fn table(
        &self,
        name: &str,
    ) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
        self.inner.table(name).await
    }

    // Returns the previously-registered table under `name`, if any.
    fn register_table(
        &self,
        name: String,
        table: Arc<dyn TableProvider>,
    ) -> Result<Option<Arc<dyn TableProvider>>> {
        self.inner.register_table(name, table)
    }

    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
        self.inner.deregister_table(name)
    }

    fn table_exist(&self, name: &str) -> bool {
        self.inner.table_exist(name)
    }
}


/// This catalog provider is intended only for unit tests. Its `Default`
/// prepopulates one schema, "my_schema", backed by `FixedSchemaProvider`.
/// (The previous note about fruit-named schemas was stale: no such
/// restriction exists — registration delegates to `MemoryCatalogProvider`.)
#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)]
#[derive(Debug)]
pub(crate) struct MyCatalogProvider {
    // Delegation target; every CatalogProvider method forwards to this.
    inner: MemoryCatalogProvider,
}

impl Default for MyCatalogProvider {
    /// Create a catalog pre-populated with a single schema, "my_schema".
    fn default() -> Self {
        let inner = MemoryCatalogProvider::new();

        let schema_name: &str = "my_schema";
        // Fail loudly if registration is rejected — consistent with
        // `FixedSchemaProvider::default`, which unwraps its registration,
        // instead of silently discarding the Result. The previous-schema
        // Option is intentionally ignored.
        let _ = inner
            .register_schema(schema_name, Arc::new(FixedSchemaProvider::default()))
            .unwrap();

        Self { inner }
    }
}

// Forwarding implementation: every CatalogProvider method delegates to the
// wrapped MemoryCatalogProvider without modification.
impl CatalogProvider for MyCatalogProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema_names(&self) -> Vec<String> {
        self.inner.schema_names()
    }

    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
        self.inner.schema(name)
    }

    // Returns the previously-registered schema under `name`, if any.
    fn register_schema(
        &self,
        name: &str,
        schema: Arc<dyn SchemaProvider>,
    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
        self.inner.register_schema(name, schema)
    }

    fn deregister_schema(
        &self,
        name: &str,
        cascade: bool,
    ) -> Result<Option<Arc<dyn SchemaProvider>>> {
        self.inner.deregister_schema(name, cascade)
    }
}

#[pymethods]
impl MyCatalogProvider {
    /// Python constructor: builds the default, pre-populated catalog.
    #[new]
    pub fn new() -> Self {
        Self {
            inner: Default::default(),
        }
    }

    /// Export this catalog through the DataFusion FFI PyCapsule interface.
    ///
    /// NOTE(review): the capsule wraps a *fresh* `MyCatalogProvider::default()`
    /// rather than `self`, so state mutated on this Python object after
    /// construction would not be visible through the capsule. For this
    /// stateless example the two are equivalent — confirm before reusing the
    /// pattern for a stateful provider.
    pub fn __datafusion_catalog_provider__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyCapsule>> {
        // Capsule name; consumers validate it before dereferencing the payload.
        let name = cr"datafusion_catalog_provider".into();
        let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None);

        PyCapsule::new(py, catalog_provider, Some(name))
    }
}
3 changes: 3 additions & 0 deletions examples/datafusion-ffi-example/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
use crate::catalog_provider::MyCatalogProvider;
use pyo3::prelude::*;

pub(crate) mod table_function;
pub(crate) mod table_provider;
pub(crate) mod catalog_provider;

/// Python extension-module entry point: registers the example FFI classes
/// so they are importable from `datafusion_ffi_example`.
#[pymodule]
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<MyTableProvider>()?;
    m.add_class::<MyTableFunction>()?;
    m.add_class::<MyCatalogProvider>()?;
    Ok(())
}
15 changes: 15 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class TableProviderExportable(Protocol):
def __datafusion_table_provider__(self) -> object: ... # noqa: D105


class CatalogProviderExportable(Protocol):
    """Type hint for object that has __datafusion_catalog_provider__ PyCapsule.

    The method must return a PyCapsule wrapping a catalog provider; see
    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
    """

    def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105


class SessionConfig:
"""Session configuration options."""

Expand Down Expand Up @@ -742,6 +751,12 @@ def deregister_table(self, name: str) -> None:
"""Remove a table from the session."""
self.ctx.deregister_table(name)

def register_catalog_provider(
    self, name: str, provider: CatalogProviderExportable
) -> None:
    """Register a catalog provider under ``name``.

    ``provider`` must expose a ``__datafusion_catalog_provider__`` method
    returning a PyCapsule (see ``CatalogProviderExportable``).
    """
    self.ctx.register_catalog_provider(name, provider)

def register_table_provider(
self, name: str, provider: TableProviderExportable
) -> None:
Expand Down
34 changes: 34 additions & 0 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::CatalogProvider;
use datafusion::common::TableReference;
use datafusion::common::{exec_err, ScalarValue};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
Expand All @@ -70,6 +71,7 @@ use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use tokio::task::JoinHandle;

Expand Down Expand Up @@ -582,6 +584,38 @@ impl PySessionContext {
Ok(())
}

pub fn register_catalog_provider(
&mut self,
name: &str,
provider: Bound<'_, PyAny>,
) -> PyDataFusionResult<()> {
if provider.hasattr("__datafusion_catalog_provider__")? {
let capsule = provider.getattr("__datafusion_catalog_provider__")?.call0()?;
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_catalog_provider")?;

let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
let provider: ForeignCatalogProvider = provider.into();

let option: Option<Arc<dyn CatalogProvider>> = self.ctx.register_catalog(name, Arc::new(provider));
match option {
Some(existing) => {
println!("Catalog '{}' already existed, schema names: {:?}", name, existing.schema_names());
}
None => {
println!("Catalog '{}' registered successfully", name);
}
}

Ok(())
} else {
Err(crate::errors::PyDataFusionError::Common(
"__datafusion_catalog_provider__ does not exist on Catalog Provider object."
.to_string(),
))
}
}

/// Construct datafusion dataframe from Arrow Table
pub fn register_table_provider(
&mut self,
Expand Down
Loading