From f72e0380e9a3b0f9ddb65834c9812bda6a502e88 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 23 Jan 2022 19:10:18 +0100 Subject: [PATCH 1/2] Add a test case showing how to extract an array from a dictionary. --- examples/simple-extension/src/lib.rs | 20 ++++++++++++++++++-- examples/simple-extension/tests/test_ext.py | 8 +++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/examples/simple-extension/src/lib.rs b/examples/simple-extension/src/lib.rs index 54283c9fc..2c1eceecd 100644 --- a/examples/simple-extension/src/lib.rs +++ b/examples/simple-extension/src/lib.rs @@ -1,6 +1,10 @@ use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD}; -use numpy::{Complex64, IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; -use pyo3::{pymodule, types::PyModule, PyResult, Python}; +use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn}; +use pyo3::{ + pymodule, + types::{PyDict, PyModule}, + PyResult, Python, +}; #[pymodule] fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -52,5 +56,17 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { conj(x.as_array()).into_pyarray(py) } + #[pyfn(m)] + #[pyo3(name = "extract")] + fn extract(d: &PyDict) -> f64 { + let x = d + .get_item("x") + .unwrap() + .downcast::>() + .unwrap(); + + x.readonly().as_array().sum() + } + Ok(()) } diff --git a/examples/simple-extension/tests/test_ext.py b/examples/simple-extension/tests/test_ext.py index 11f7d9b41..4554d5230 100644 --- a/examples/simple-extension/tests/test_ext.py +++ b/examples/simple-extension/tests/test_ext.py @@ -1,5 +1,5 @@ import numpy as np -from rust_ext import axpy, conj, mult +from rust_ext import axpy, conj, mult, extract def test_axpy(): @@ -22,3 +22,9 @@ def test_mult(): def test_conj(): x = np.array([1.0 + 2j, 2.0 + 3j, 3.0 + 4j]) np.testing.assert_array_almost_equal(conj(x), np.conj(x)) + + +def test_extract(): + x = np.arange(5) + d = { "x": x } + np.testing.assert_almost_equal(extract(d), 10.0) From c8390e3179d1bb9ca2e4500d2a01018c292e5c7f Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 23 Jan 2022 19:43:03 +0100 Subject: [PATCH 2/2] Manually implement PyTypeInfo to ensure that downcasting considers dtype and ndim. --- CHANGELOG.md | 1 + examples/simple-extension/tests/test_ext.py | 2 +- src/array.rs | 36 +++++++++++++-------- src/readonly.rs | 2 +- tests/array.rs | 30 +++++++++++++++-- 5 files changed, 53 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa33f0f23..dd31173c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Unreleased - Support object arrays ([#216](https://github.com/PyO3/rust-numpy/pull/216)) - Support borrowing arrays that are part of other Python objects via `PyArray::borrow_from_array` ([#230](https://github.com/PyO3/rust-numpy/pull/216)) + - Fixed downcasting ignoring element type and dimensionality ([#265](https://github.com/PyO3/rust-numpy/pull/265)) - `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220)) - `PyArray::from_exact_iter` does not unsoundly trust `ExactSizeIterator::len` any more ([#262](https://github.com/PyO3/rust-numpy/pull/262)) - `PyArray::as_cell_slice` was removed as it unsoundly interacts with `PyReadonlyArray` allowing safe code to violate aliasing rules ([#260](https://github.com/PyO3/rust-numpy/pull/260)) diff --git a/examples/simple-extension/tests/test_ext.py b/examples/simple-extension/tests/test_ext.py index 4554d5230..c163f71a0 100644 --- a/examples/simple-extension/tests/test_ext.py +++ b/examples/simple-extension/tests/test_ext.py @@ -25,6 +25,6 @@ def test_conj(): def test_extract(): - x = np.arange(5) + x = np.arange(5.0) d = { "x": x } np.testing.assert_almost_equal(extract(d), 10.0) diff --git a/src/array.rs b/src/array.rs index c94ff85f9..155f8bd53 100644 --- a/src/array.rs +++ b/src/array.rs @@ -13,9 +13,9 @@ use ndarray::{ }; use num_traits::AsPrimitive; use pyo3::{ - ffi, pyobject_native_type_info, pyobject_native_type_named, type_object, types::PyModule, - AsPyPointer, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject, - PyResult, Python, ToPyObject, + ffi, pyobject_native_type_named, type_object, types::PyModule, AsPyPointer, FromPyObject, + IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, }; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; @@ -110,16 +110,24 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> { } unsafe impl type_object::PyLayout> for npyffi::PyArrayObject {} + impl type_object::PySizedLayout> for npyffi::PyArrayObject {} -pyobject_native_type_info!( - PyArray, - *npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type), - Some("numpy"), - #checkfunction=npyffi::PyArray_Check - ; T - ; D -); +unsafe impl PyTypeInfo for PyArray { + type AsRefTarget = Self; + + const NAME: &'static str = "PyArray"; + const MODULE: ::std::option::Option<&'static str> = Some("numpy"); + + #[inline] + fn type_object_raw(_py: Python) -> *mut ffi::PyTypeObject { + unsafe { npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type) } + } + + fn is_type_of(ob: &PyAny) -> bool { + <&Self>::extract(ob).is_ok() + } +} pyobject_native_type_named!(PyArray ; T ; D); @@ -129,12 +137,12 @@ impl IntoPy for PyArray { } } -impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray { +impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray { // here we do type-check three times // 1. Checks if the object is PyArray // 2. Checks if the data type of the array is T // 3. Checks if the dimension is same as D - fn extract(ob: &'a PyAny) -> PyResult { + fn extract(ob: &'py PyAny) -> PyResult { let array = unsafe { if npyffi::PyArray_Check(ob.as_ptr()) == 0 { return Err(PyDowncastError::new(ob, "PyArray").into()); @@ -207,7 +215,7 @@ impl PyArray { /// assert!(array.is_contiguous()); /// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py); /// let not_contiguous: &numpy::PyArray1 = py - /// .eval("np.zeros((3, 5))[::2, 4]", Some(locals), None) + /// .eval("np.zeros((3, 5), dtype='float32')[::2, 4]", Some(locals), None) /// .unwrap() /// .downcast() /// .unwrap(); diff --git a/src/readonly.rs b/src/readonly.rs index 6fd2adeb6..7a25a0eb6 100644 --- a/src/readonly.rs +++ b/src/readonly.rs @@ -66,7 +66,7 @@ impl<'py, T: Element, D: Dimension> PyReadonlyArray<'py, T, D> { /// assert_eq!(readonly.as_slice().unwrap(), &[0, 1, 2, 3]); /// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py); /// let not_contiguous: &PyArray1 = py - /// .eval("np.arange(10)[::2]", Some(locals), None) + /// .eval("np.arange(10, dtype='int32')[::2]", Some(locals), None) /// .unwrap() /// .downcast() /// .unwrap(); diff --git a/tests/array.rs b/tests/array.rs index caa1c633a..9d54e8be8 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -2,8 +2,7 @@ use ndarray::*; use numpy::*; use pyo3::{ prelude::*, - types::PyList, - types::{IntoPyDict, PyDict}, + types::{IntoPyDict, PyDict, PyList}, }; fn get_np_locals(py: Python) -> &PyDict { @@ -300,3 +299,30 @@ fn borrow_from_array() { py_run!(py, array, "assert array.shape == (10,)"); }); } + +#[test] +fn downcasting_works() { + Python::with_gil(|py| { + let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]); + + assert!(ob.downcast::>().is_ok()); + }) +} + +#[test] +fn downcasting_respects_element_type() { + Python::with_gil(|py| { + let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]); + + assert!(ob.downcast::>().is_err()); + }) +} + +#[test] +fn downcasting_respects_dimensionality() { + Python::with_gil(|py| { + let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]); + + assert!(ob.downcast::>().is_err()); + }) +}