Skip to content

Commit 2da4787

Browse files
authored
Merge pull request #255 from PyO3/unsound-custom-element
Remove unsound custom element example.
2 parents 54b14e0 + fdb366e commit 2da4787

File tree

4 files changed

+58
-44
lines changed

4 files changed

+58
-44
lines changed

src/array.rs

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -629,23 +629,25 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
629629
}
630630
}
631631

632-
/// Construct PyArray from
633-
/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html).
632+
/// Constructs a `PyArray` from [`ndarray::Array`]
634633
///
635-
/// This method uses internal [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html)
636-
/// of `ndarray::Array` as numpy array.
634+
/// This method uses the internal [`Vec`] of the `ndarray::Array` as the base object of the NumPy array.
637635
///
638636
/// # Example
637+
///
639638
/// ```
640-
/// # #[macro_use] extern crate ndarray;
639+
/// use ndarray::array;
641640
/// use numpy::PyArray;
641+
///
642642
/// pyo3::Python::with_gil(|py| {
643643
/// let pyarray = PyArray::from_owned_array(py, array![[1, 2], [3, 4]]);
644644
/// assert_eq!(pyarray.readonly().as_array(), array![[1, 2], [3, 4]]);
645645
/// });
646646
/// ```
647647
pub fn from_owned_array<'py>(py: Python<'py>, arr: Array<T, D>) -> &'py Self {
648-
IntoPyArray::into_pyarray(arr, py)
648+
let (strides, dims) = (arr.npy_strides(), arr.raw_dim());
649+
let data_ptr = arr.as_ptr();
650+
unsafe { PyArray::from_raw_parts(py, dims, strides.as_ptr(), data_ptr, arr) }
649651
}
650652

651653
/// Get the immutable reference of the specified element, with checking the passed index is valid.
@@ -858,6 +860,48 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
858860
}
859861
}
860862

863+
impl<D: Dimension> PyArray<PyObject, D> {
864+
/// Constructs a `PyArray` containing objects from [`ndarray::Array`]
865+
///
866+
/// This method uses the internal [`Vec`] of the `ndarray::Array` as the base object the NumPy array.
867+
///
868+
/// # Example
869+
///
870+
/// ```
871+
/// use ndarray::array;
872+
/// use pyo3::{pyclass, Py, Python};
873+
/// use numpy::PyArray;
874+
///
875+
/// #[pyclass]
876+
/// struct CustomElement {
877+
/// foo: i32,
878+
/// bar: f64,
879+
/// }
880+
///
881+
/// Python::with_gil(|py| {
882+
/// let array = array![
883+
/// Py::new(py, CustomElement {
884+
/// foo: 1,
885+
/// bar: 2.0,
886+
/// }).unwrap(),
887+
/// Py::new(py, CustomElement {
888+
/// foo: 3,
889+
/// bar: 4.0,
890+
/// }).unwrap(),
891+
/// ];
892+
///
893+
/// let pyarray = PyArray::from_owned_object_array(py, array);
894+
///
895+
/// assert!(pyarray.readonly().get(0).unwrap().as_ref(py).is_instance::<CustomElement>().unwrap());
896+
/// });
897+
/// ```
898+
pub fn from_owned_object_array<'py, T>(py: Python<'py>, arr: Array<Py<T>, D>) -> &'py Self {
899+
let (strides, dims) = (arr.npy_strides(), arr.raw_dim());
900+
let data_ptr = arr.as_ptr() as *const PyObject;
901+
unsafe { PyArray::from_raw_parts(py, dims, strides.as_ptr(), data_ptr, arr) }
902+
}
903+
}
904+
861905
impl<T: Copy + Element> PyArray<T, Ix0> {
862906
/// Get the element of zero-dimensional PyArray.
863907
///

src/convert.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ where
6363
type Item = A;
6464
type Dim = D;
6565
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
66-
let (strides, dims) = (self.npy_strides(), self.raw_dim());
67-
let data_ptr = self.as_ptr();
68-
unsafe { PyArray::from_raw_parts(py, dims, strides.as_ptr(), data_ptr, self) }
66+
PyArray::from_owned_array(py, self)
6967
}
7068
}
7169

src/dtype.rs

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ impl PyArrayDescr {
292292

293293
/// Represents that a type can be an element of `PyArray`.
294294
///
295-
/// Currently, only integer/float/complex types are supported.
295+
/// Currently, only integer/float/complex/object types are supported.
296296
/// If you come up with a nice implementation for some other types, we're happy to receive your PR :)
297297
/// You may refer to the [numpy document](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types)
298298
/// for all types that numpy supports.
@@ -310,38 +310,12 @@ impl PyArrayDescr {
310310
///
311311
/// # Custom element types
312312
///
313-
/// You can implement this trait to manage arrays of custom element types, but they still need to be stored
314-
/// on Python's heap using PyO3's [Py](pyo3::Py) type.
313+
/// Note that we cannot safely store `Py<T>` where `T: PyClass`, because the type information would be
314+
/// eliminated in the resulting NumPy array.
315+
/// In other words, objects are always treated as `Py<PyAny>` (a.k.a. `PyObject`) by Python code,
316+
/// and only `Py<PyAny>` can be stored in a type safe manner.
315317
///
316-
/// ```
317-
/// use numpy::{ndarray::Array2, Element, PyArray, PyArrayDescr, ToPyArray};
318-
/// use pyo3::{pyclass, Py, Python};
319-
///
320-
/// #[pyclass]
321-
/// pub struct CustomElement;
322-
///
323-
/// // The transparent wrapper is necessary as one cannot implement
324-
/// // a foreign trait (`Element`) on a foreign type (`Py`) directly.
325-
/// #[derive(Clone)]
326-
/// #[repr(transparent)]
327-
/// pub struct Wrapper(pub Py<CustomElement>);
328-
///
329-
/// unsafe impl Element for Wrapper {
330-
/// const IS_COPY: bool = false;
331-
///
332-
/// fn get_dtype(py: Python) -> &PyArrayDescr {
333-
/// PyArrayDescr::object(py)
334-
/// }
335-
/// }
336-
///
337-
/// Python::with_gil(|py| {
338-
/// let array = Array2::<Wrapper>::from_shape_fn((2, 3), |(_i, _j)| {
339-
/// Wrapper(Py::new(py, CustomElement).unwrap())
340-
/// });
341-
///
342-
/// let _array: &PyArray<Wrapper, _> = array.to_pyarray(py);
343-
/// });
344-
/// ```
318+
/// You can however create `ndarray::Array<Py<T>, D>` and turn that into a NumPy array safely and efficiently using [`from_owned_object_array`][crate::PyArray::from_owned_object_array].
345319
pub unsafe trait Element: Clone + Send {
346320
/// Flag that indicates whether this type is trivially copyable.
347321
///

src/slice_container.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ use pyo3::pyclass_slots::PyClassDummySlot;
77
use pyo3::type_object::{LazyStaticType, PyTypeInfo};
88
use pyo3::{ffi, types::PyAny, PyCell};
99

10-
use crate::dtype::Element;
11-
1210
/// Utility type to safely store Box<[_]> or Vec<_> on the Python heap
1311
pub(crate) struct PySliceContainer {
1412
ptr: *mut u8,
@@ -69,7 +67,7 @@ impl<T: Send> From<Vec<T>> for PySliceContainer {
6967

7068
impl<A, D> From<ArrayBase<OwnedRepr<A>, D>> for PySliceContainer
7169
where
72-
A: Element,
70+
A: Send,
7371
D: Dimension,
7472
{
7573
fn from(data: ArrayBase<OwnedRepr<A>, D>) -> Self {

0 commit comments

Comments
 (0)