Skip to content

Commit fff48b8

Browse files
authored
Merge pull request #302 from PyO3/borrow-checking-redux
Fixes and refinements to dynamic borrow checking
2 parents 7a5e011 + a013e7b commit fff48b8

File tree

6 files changed

+202
-73
lines changed

6 files changed

+202
-73
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
- Unreleased
44
- Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274))
5+
- The deprecated iterator builders `NpySingleIterBuilder::{readonly,readwrite}` and `NpyMultiIterBuilder::add_{readonly,readwrite}` now take referencces to `PyReadonlyArray` and `PyReadwriteArray` instead of consuming them.
6+
- The destructive `PyArray::resize` method is now unsafe if used without an instance of `PyReadwriteArray`. ([#302](https://github.com/PyO3/rust-numpy/pull/302))
57
- The `inner`, `dot` and `einsum` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#285](https://github.com/PyO3/rust-numpy/pull/285))
68
- Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292))
79

src/array.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
2323
use crate::cold;
2424
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
2525
use crate::dtype::{Element, PyArrayDescr};
26-
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
26+
use crate::error::{BorrowError, DimensionalityError, FromVecError, NotContiguousError, TypeError};
2727
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
2828
use crate::slice_container::PySliceContainer;
2929

@@ -846,13 +846,33 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
846846
}
847847

848848
/// Get an immutable borrow of the NumPy array
849+
pub fn try_readonly(&self) -> Result<PyReadonlyArray<'_, T, D>, BorrowError> {
850+
PyReadonlyArray::try_new(self)
851+
}
852+
853+
/// Get an immutable borrow of the NumPy array
854+
///
855+
/// # Panics
856+
///
857+
/// Panics if the allocation backing the array is currently mutably borrowed.
858+
/// For a non-panicking variant, use [`try_readonly`][Self::try_readonly].
849859
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
850-
PyReadonlyArray::try_new(self).unwrap()
860+
self.try_readonly().unwrap()
851861
}
852862

853863
/// Get a mutable borrow of the NumPy array
864+
pub fn try_readwrite(&self) -> Result<PyReadwriteArray<'_, T, D>, BorrowError> {
865+
PyReadwriteArray::try_new(self)
866+
}
867+
868+
/// Get a mutable borrow of the NumPy array
869+
///
870+
/// # Panics
871+
///
872+
/// Panics if the allocation backing the array is currently borrowed.
873+
/// For a non-panicking variant, use [`try_readwrite`][Self::try_readwrite].
854874
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
855-
PyReadwriteArray::try_new(self).unwrap()
875+
self.try_readwrite().unwrap()
856876
}
857877

858878
/// Returns the internal array as [`ArrayView`].
@@ -1057,19 +1077,30 @@ impl<T: Element> PyArray<T, Ix1> {
10571077
data.into_pyarray(py)
10581078
}
10591079

1060-
/// Extends or trancates the length of 1 dimension PyArray.
1080+
/// Extends or truncates the length of a one-dimensional array.
1081+
///
1082+
/// # Safety
1083+
///
1084+
/// There should be no outstanding references (shared or exclusive) into the array
1085+
/// as this method might re-allocate it and thereby invalidate all pointers into it.
10611086
///
10621087
/// # Example
1088+
///
10631089
/// ```
10641090
/// use numpy::PyArray;
1065-
/// pyo3::Python::with_gil(|py| {
1091+
/// use pyo3::Python;
1092+
///
1093+
/// Python::with_gil(|py| {
10661094
/// let pyarray = PyArray::arange(py, 0, 10, 1);
10671095
/// assert_eq!(pyarray.len(), 10);
1068-
/// pyarray.resize(100).unwrap();
1096+
///
1097+
/// unsafe {
1098+
/// pyarray.resize(100).unwrap();
1099+
/// }
10691100
/// assert_eq!(pyarray.len(), 100);
10701101
/// });
10711102
/// ```
1072-
pub fn resize(&self, new_elems: usize) -> PyResult<()> {
1103+
pub unsafe fn resize(&self, new_elems: usize) -> PyResult<()> {
10731104
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
10741105
}
10751106

src/borrow.rs

Lines changed: 129 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,86 @@ impl BorrowFlags {
170170
unsafe fn get(&self) -> &mut HashMap<usize, isize> {
171171
(*self.0.get()).get_or_insert_with(HashMap::new)
172172
}
173+
174+
fn acquire<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError> {
175+
let address = base_address(array);
176+
177+
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
178+
// and we are not calling into user code which might re-enter this function.
179+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
180+
181+
match borrow_flags.entry(address) {
182+
Entry::Occupied(entry) => {
183+
let readers = entry.into_mut();
184+
185+
let new_readers = readers.wrapping_add(1);
186+
187+
if new_readers <= 0 {
188+
cold();
189+
return Err(BorrowError::AlreadyBorrowed);
190+
}
191+
192+
*readers = new_readers;
193+
}
194+
Entry::Vacant(entry) => {
195+
entry.insert(1);
196+
}
197+
}
198+
199+
Ok(())
200+
}
201+
202+
fn release<T, D>(&self, array: &PyArray<T, D>) {
203+
let address = base_address(array);
204+
205+
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
206+
// and we are not calling into user code which might re-enter this function.
207+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
208+
209+
let readers = borrow_flags.get_mut(&address).unwrap();
210+
211+
*readers -= 1;
212+
213+
if *readers == 0 {
214+
borrow_flags.remove(&address).unwrap();
215+
}
216+
}
217+
218+
fn acquire_mut<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError> {
219+
let address = base_address(array);
220+
221+
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
222+
// and we are not calling into user code which might re-enter this function.
223+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
224+
225+
match borrow_flags.entry(address) {
226+
Entry::Occupied(entry) => {
227+
let writers = entry.into_mut();
228+
229+
if *writers != 0 {
230+
cold();
231+
return Err(BorrowError::AlreadyBorrowed);
232+
}
233+
234+
*writers = -1;
235+
}
236+
Entry::Vacant(entry) => {
237+
entry.insert(-1);
238+
}
239+
}
240+
241+
Ok(())
242+
}
243+
244+
fn release_mut<T, D>(&self, array: &PyArray<T, D>) {
245+
let address = base_address(array);
246+
247+
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
248+
// and we are not calling into user code which might re-enter this function.
249+
let borrow_flags = unsafe { self.get() };
250+
251+
borrow_flags.remove(&address).unwrap();
252+
}
173253
}
174254

175255
static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
@@ -224,29 +304,7 @@ where
224304
D: Dimension,
225305
{
226306
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
227-
let address = base_address(array);
228-
229-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
230-
// and we are not calling into user code which might re-enter this function.
231-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
232-
233-
match borrow_flags.entry(address) {
234-
Entry::Occupied(entry) => {
235-
let readers = entry.into_mut();
236-
237-
let new_readers = readers.wrapping_add(1);
238-
239-
if new_readers <= 0 {
240-
cold();
241-
return Err(BorrowError::AlreadyBorrowed);
242-
}
243-
244-
*readers = new_readers;
245-
}
246-
Entry::Vacant(entry) => {
247-
entry.insert(1);
248-
}
249-
}
307+
BORROW_FLAGS.acquire(array)?;
250308

251309
Ok(Self(array))
252310
}
@@ -275,21 +333,19 @@ where
275333
}
276334
}
277335

336+
impl<'a, T, D> Clone for PyReadonlyArray<'a, T, D>
337+
where
338+
T: Element,
339+
D: Dimension,
340+
{
341+
fn clone(&self) -> Self {
342+
Self::try_new(self.0).unwrap()
343+
}
344+
}
345+
278346
impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> {
279347
fn drop(&mut self) {
280-
let address = base_address(self.0);
281-
282-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
283-
// and we are not calling into user code which might re-enter this function.
284-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
285-
286-
let readers = borrow_flags.get_mut(&address).unwrap();
287-
288-
*readers -= 1;
289-
290-
if *readers == 0 {
291-
borrow_flags.remove(&address).unwrap();
292-
}
348+
BORROW_FLAGS.release(self.0);
293349
}
294350
}
295351

@@ -348,27 +404,7 @@ where
348404
return Err(BorrowError::NotWriteable);
349405
}
350406

351-
let address = base_address(array);
352-
353-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
354-
// and we are not calling into user code which might re-enter this function.
355-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
356-
357-
match borrow_flags.entry(address) {
358-
Entry::Occupied(entry) => {
359-
let writers = entry.into_mut();
360-
361-
if *writers != 0 {
362-
cold();
363-
return Err(BorrowError::AlreadyBorrowed);
364-
}
365-
366-
*writers = -1;
367-
}
368-
Entry::Vacant(entry) => {
369-
entry.insert(-1);
370-
}
371-
}
407+
BORROW_FLAGS.acquire_mut(array)?;
372408

373409
Ok(Self(array))
374410
}
@@ -397,15 +433,44 @@ where
397433
}
398434
}
399435

400-
impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
401-
fn drop(&mut self) {
402-
let address = base_address(self.0);
436+
impl<'py, T> PyReadwriteArray<'py, T, Ix1>
437+
where
438+
T: Element,
439+
{
440+
/// Extends or truncates the length of a one-dimensional array.
441+
///
442+
/// # Example
443+
///
444+
/// ```
445+
/// use numpy::PyArray;
446+
/// use pyo3::Python;
447+
///
448+
/// Python::with_gil(|py| {
449+
/// let pyarray = PyArray::arange(py, 0, 10, 1);
450+
/// assert_eq!(pyarray.len(), 10);
451+
///
452+
/// let pyarray = pyarray.readwrite();
453+
/// let pyarray = pyarray.resize(100).unwrap();
454+
/// assert_eq!(pyarray.len(), 100);
455+
/// });
456+
/// ```
457+
pub fn resize(self, new_elems: usize) -> PyResult<Self> {
458+
BORROW_FLAGS.release_mut(self.0);
459+
460+
// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
461+
unsafe {
462+
self.0.resize(new_elems)?;
463+
}
403464

404-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
405-
// and we are not calling into user code which might re-enter this function.
406-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
465+
BORROW_FLAGS.acquire_mut(self.0)?;
407466

408-
borrow_flags.remove(&address).unwrap();
467+
Ok(self)
468+
}
469+
}
470+
471+
impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
472+
fn drop(&mut self) {
473+
BORROW_FLAGS.release_mut(self.0);
409474
}
410475
}
411476

src/convert.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ use crate::sealed::Sealed;
2929
/// assert_eq!(py_array.readonly().as_slice().unwrap(), &[1, 2, 3]);
3030
///
3131
/// // Array cannot be resized when its data is owned by Rust.
32-
/// assert!(py_array.resize(100).is_err());
32+
/// unsafe {
33+
/// assert!(py_array.resize(100).is_err());
34+
/// }
3335
/// });
3436
/// ```
3537
pub trait IntoPyArray {

tests/borrow.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,19 @@ fn borrows_span_threads() {
115115
});
116116
}
117117

118+
#[test]
119+
fn shared_borrows_can_be_cloned() {
120+
Python::with_gil(|py| {
121+
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
122+
123+
let shared1 = array.readonly();
124+
let shared2 = shared1.clone();
125+
126+
assert_eq!(shared2.shape(), [1, 2, 3]);
127+
assert_eq!(shared1.shape(), [1, 2, 3]);
128+
});
129+
}
130+
118131
#[test]
119132
#[should_panic(expected = "AlreadyBorrowed")]
120133
fn overlapping_views_conflict() {
@@ -235,3 +248,17 @@ fn readwrite_as_array_slice() {
235248
assert_eq!(*array.get_mut([0, 1, 2]).unwrap(), 0.0);
236249
});
237250
}
251+
252+
#[test]
253+
fn resize_using_exclusive_borrow() {
254+
Python::with_gil(|py| {
255+
let array = PyArray::<f64, _>::zeros(py, 3, false);
256+
assert_eq!(array.shape(), [3]);
257+
258+
let mut array = array.readwrite();
259+
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 3]);
260+
261+
let mut array = array.resize(5).unwrap();
262+
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 5]);
263+
});
264+
}

tests/to_py.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ fn into_pyarray_cannot_resize() {
161161
Python::with_gil(|py| {
162162
let arr = vec![1, 2, 3].into_pyarray(py);
163163

164-
assert!(arr.resize(100).is_err())
164+
unsafe {
165+
assert!(arr.resize(100).is_err());
166+
}
165167
});
166168
}
167169

0 commit comments

Comments
 (0)