Skip to content

Commit 32c5b7a

Browse files
committed
switch to Lockable
1 parent 129fda4 commit 32c5b7a

File tree

7 files changed

+89
-74
lines changed

7 files changed

+89
-74
lines changed

pysrc/juliacall/juliapkg.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
"packages": {
44
"PythonCall": {
55
"uuid": "6099a3de-0909-46bc-b1f4-468b9a2dfc0d",
6-
"version": "=0.9.25"
6+
"version": "=0.9.25",
7+
"path": "/Users/mcranmer/PermaDocuments/PythonCall.jl/",
8+
"dev": true
79
},
810
"OpenSSL_jll": {
911
"uuid": "458c3c95-2e84-50aa-8efc-19380b2a3a95",

src/Convert/Convert.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using ..Core
99
using ..Core:
1010
C,
1111
Utils,
12+
Lockable,
1213
@autopy,
1314
getptr,
1415
incref,

src/Convert/pyconvert.jl

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ struct PyConvertRule
1212
priority::PyConvertPriority
1313
end
1414

15-
const PYCONVERT_RULES = Dict{String,Vector{PyConvertRule}}()
16-
const PYCONVERT_EXTRATYPES = Py[]
15+
const PYCONVERT_RULES = Lockable(Dict{String,Vector{PyConvertRule}}())
16+
const PYCONVERT_EXTRATYPES = Lockable(Py[])
1717

1818
"""
1919
pyconvert_add_rule(tname::String, T::Type, func::Function, priority::PyConvertPriority=PYCONVERT_PRIORITY_NORMAL)
@@ -69,11 +69,11 @@ function pyconvert_add_rule(
6969
priority::PyConvertPriority = PYCONVERT_PRIORITY_NORMAL,
7070
)
7171
@nospecialize type func
72-
push!(
73-
get!(Vector{PyConvertRule}, PYCONVERT_RULES, pytypename),
72+
Base.@lock PYCONVERT_RULES push!(
73+
get!(Vector{PyConvertRule}, PYCONVERT_RULES[], pytypename),
7474
PyConvertRule(type, func, priority),
7575
)
76-
Base.@lock PYCONVERT_RULES_CACHE_LOCK empty!.(values(PYCONVERT_RULES_CACHE))
76+
Base.@lock PYCONVERT_RULES_CACHE empty!.(values(PYCONVERT_RULES_CACHE[]))
7777
return
7878
end
7979

@@ -163,7 +163,7 @@ function _pyconvert_get_rules(pytype::Py)
163163
omro = collect(pytype.__mro__)
164164
basetypes = Py[pytype]
165165
basemros = Vector{Py}[omro]
166-
for xtype in PYCONVERT_EXTRATYPES
166+
Base.@lock PYCONVERT_EXTRATYPES for xtype in PYCONVERT_EXTRATYPES[]
167167
# find the topmost supertype of
168168
xbase = PyNULL
169169
for base in omro
@@ -248,9 +248,9 @@ function _pyconvert_get_rules(pytype::Py)
248248
mro = String[x for xs in xmro for x in xs]
249249

250250
# get corresponding rules
251-
rules = PyConvertRule[
251+
rules = Base.@lock PYCONVERT_RULES PyConvertRule[
252252
rule for tname in mro for
253-
rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES, tname)
253+
rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES[], tname)
254254
]
255255

256256
# order the rules by priority, then by original order
@@ -261,11 +261,10 @@ function _pyconvert_get_rules(pytype::Py)
261261
return rules
262262
end
263263

264-
const PYCONVERT_PREFERRED_TYPE = Dict{Py,Type}()
265-
const PYCONVERT_PREFERRED_TYPE_LOCK = Threads.SpinLock()
264+
const PYCONVERT_PREFERRED_TYPE = Lockable(Dict{Py,Type}())
266265

267266
pyconvert_preferred_type(pytype::Py) =
268-
Base.@lock PYCONVERT_PREFERRED_TYPE_LOCK get!(PYCONVERT_PREFERRED_TYPE, pytype) do
267+
Base.@lock PYCONVERT_PREFERRED_TYPE get!(PYCONVERT_PREFERRED_TYPE[], pytype) do
269268
if pyissubclass(pytype, pybuiltins.int)
270269
Union{Int,BigInt}
271270
else
@@ -308,11 +307,10 @@ end
308307

309308
pyconvert_fix(::Type{T}, func) where {T} = x -> func(T, x)
310309

311-
const PYCONVERT_RULES_CACHE = Dict{Type,Dict{C.PyPtr,Vector{Function}}}()
312-
const PYCONVERT_RULES_CACHE_LOCK = Threads.SpinLock()
310+
const PYCONVERT_RULES_CACHE = Lockable(Dict{Type,Dict{C.PyPtr,Vector{Function}}}())
313311

314312
@generated pyconvert_rules_cache(::Type{T}) where {T} =
315-
Base.@lock PYCONVERT_RULES_CACHE_LOCK get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE, T)
313+
Base.@lock PYCONVERT_RULES_CACHE get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE[], T)
316314

317315
function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
318316
if T isa Union
@@ -353,12 +351,13 @@ function pytryconvert(::Type{T}, x_) where {T}
353351
# get rules from the cache
354352
# TODO: we should hold weak references and clear the cache if types get deleted
355353
tptr = C.Py_Type(x)
356-
trules = pyconvert_rules_cache(T)
357-
rules = Base.@lock PYCONVERT_RULES_CACHE_LOCK get!(trules, tptr) do
358-
t = pynew(incref(tptr))
359-
ans = pyconvert_get_rules(T, t)::Vector{Function}
360-
pydel!(t)
361-
ans
354+
rules = Base.@lock PYCONVERT_RULES_CACHE let trules = pyconvert_rules_cache(T)
355+
get!(trules, tptr) do
356+
t = pynew(incref(tptr))
357+
ans = pyconvert_get_rules(T, t)::Vector{Function}
358+
pydel!(t)
359+
ans
360+
end
362361
end
363362

364363
# apply the rules
@@ -420,15 +419,17 @@ pyconvertarg(::Type{T}, x, name) where {T} = @autopy x @pyconvert T x_ begin
420419
end
421420

422421
function init_pyconvert()
423-
push!(PYCONVERT_EXTRATYPES, pyimport("io" => "IOBase"))
424-
push!(
425-
PYCONVERT_EXTRATYPES,
426-
pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))...,
427-
)
428-
push!(
429-
PYCONVERT_EXTRATYPES,
430-
pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))...,
431-
)
422+
Base.@lock PYCONVERT_EXTRATYPES begin
423+
push!(PYCONVERT_EXTRATYPES[], pyimport("io" => "IOBase"))
424+
push!(
425+
PYCONVERT_EXTRATYPES[],
426+
pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))...,
427+
)
428+
push!(
429+
PYCONVERT_EXTRATYPES[],
430+
pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))...,
431+
)
432+
end
432433

433434
priority = PYCONVERT_PRIORITY_CANONICAL
434435
pyconvert_add_rule("builtins:NoneType", Nothing, pyconvert_rule_none, priority)

src/Core/Core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ const ROOT_DIR = dirname(dirname(@__DIR__))
1111
using ..PythonCall: PythonCall # needed for docstring cross-refs
1212
using ..C: C
1313
using ..GC: GC
14-
using ..Utils: Utils
14+
using ..Utils: Utils, Lockable
1515
using Base: @propagate_inbounds, @kwdef
1616
using Dates:
1717
Date,

src/Core/Py.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ decref(x::Py) = Base.GC.@preserve x (decref(getptr(x)); x)
5656

5757
Base.unsafe_convert(::Type{C.PyPtr}, x::Py) = getptr(x)
5858

59-
const PYNULL_CACHE = Py[]
60-
const PYNULL_CACHE_LOCK = Threads.SpinLock()
59+
const PYNULL_CACHE = Lockable(Py[])
6160

6261
"""
6362
pynew([ptr])
@@ -70,11 +69,11 @@ points at, i.e. the new `Py` object owns a reference.
7069
Note that NULL Python objects are not safe in the sense that most API functions will probably
7170
crash your Julia session if you pass a NULL argument.
7271
"""
73-
pynew() = Base.@lock PYNULL_CACHE_LOCK begin
74-
if isempty(PYNULL_CACHE)
72+
pynew() = Base.@lock PYNULL_CACHE begin
73+
if isempty(PYNULL_CACHE[])
7574
Py(Val(:new), C.PyNULL)
7675
else
77-
pop!(PYNULL_CACHE)
76+
pop!(PYNULL_CACHE[])
7877
end
7978
end
8079

@@ -121,7 +120,7 @@ function pydel!(x::Py)
121120
C.Py_DecRef(ptr)
122121
setptr!(x, C.PyNULL)
123122
end
124-
Base.@lock PYNULL_CACHE_LOCK push!(PYNULL_CACHE, x)
123+
Base.@lock PYNULL_CACHE push!(PYNULL_CACHE[], x)
125124
return
126125
end
127126

src/JlWrap/C.jl

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module Cjl
22

33
using ...C: C
4-
using ...Utils: Utils
4+
using ...Utils: Utils, Lockable
55
using Base: @kwdef
66
using UnsafePointers: UnsafePtr
77
using Serialization: serialize, deserialize
@@ -16,13 +16,7 @@ const PyJuliaBase_Type = Ref(C.PyNULL)
1616

1717
# we store the actual julia values here
1818
# the `value` field of `PyJuliaValueObject` indexes into here
19-
const PYJLVALUES = IdDict{Int,Any}()
20-
# unused indices in PYJLVALUES
21-
const PYJLFREEVALUES = Int[]
22-
# Thread safety for PYJLVALUES and PYJLFREEVALUES
23-
const PYJLVALUES_LOCK = Threads.SpinLock()
24-
# Track next available index
25-
const PYJLVALUES_NEXT_IDX = Ref(1)
19+
const PYJLVALUES = Lockable((; values=IdDict{Int,Any}(), free=Int[], next=Ref(1)))
2620

2721
function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr)
2822
o = ccall(UnsafePtr{C.PyTypeObject}(t).alloc[!], C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
@@ -35,24 +29,23 @@ end
3529
function _pyjl_dealloc(o::C.PyPtr)
3630
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
3731
if idx != 0
38-
Base.@lock PYJLVALUES_LOCK begin
39-
delete!(PYJLVALUES, idx)
40-
push!(PYJLFREEVALUES, idx)
32+
Base.@lock PYJLVALUES begin
33+
delete!(PYJLVALUES[].values, idx)
34+
push!(PYJLVALUES[].free, idx)
4135
end
4236
end
4337
UnsafePtr{PyJuliaValueObject}(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o)
4438
ccall(UnsafePtr{C.PyTypeObject}(C.Py_Type(o)).free[!], Cvoid, (C.PyPtr,), o)
4539
nothing
4640
end
4741

48-
const PYJLMETHODS = Vector{Any}()
49-
const PYJLMETHODS_LOCK = Threads.SpinLock()
42+
const PYJLMETHODS = Lockable([])
5043

5144
function PyJulia_MethodNum(f)
5245
@nospecialize f
53-
Base.@lock PYJLMETHODS_LOCK begin
54-
push!(PYJLMETHODS, f)
55-
return length(PYJLMETHODS)
46+
Base.@lock PYJLMETHODS begin
47+
push!(PYJLMETHODS[], f)
48+
return length(PYJLMETHODS[])
5649
end
5750
end
5851

@@ -67,13 +60,12 @@ function _pyjl_callmethod(o::C.PyPtr, args::C.PyPtr)
6760
@assert nargs > 0
6861
num = C.PyLong_AsLongLong(C.PyTuple_GetItem(args, 0))
6962
num == -1 && return C.PyNULL
70-
f = Base.@lock PYJLMETHODS_LOCK PYJLMETHODS[num]
63+
f = Base.@lock PYJLMETHODS PYJLMETHODS[][num]
7164
# this form gets defined in jlwrap/base.jl
7265
return _pyjl_callmethod(f, o, args, nargs)::C.PyPtr
7366
end
7467

75-
const PYJLBUFCACHE = Dict{Ptr{Cvoid},Any}()
76-
const PYJLBUFCACHE_LOCK = Threads.SpinLock()
68+
const PYJLBUFCACHE = Lockable(Dict{Ptr{Cvoid},Any}())
7769

7870
@kwdef struct PyBufferInfo{N}
7971
# data
@@ -187,9 +179,7 @@ function _pyjl_get_buffer_impl(
187179

188180
# internal
189181
cptr = Base.pointer_from_objref(c)
190-
Base.@lock PYJLBUFCACHE_LOCK begin
191-
PYJLBUFCACHE[cptr] = c
192-
end
182+
Base.@lock PYJLBUFCACHE PYJLBUFCACHE[][cptr] = c
193183
b.internal[] = cptr
194184

195185
# obj
@@ -207,7 +197,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
207197
C.Py_DecRef(num_)
208198
num == -1 && return Cint(-1)
209199
try
210-
f = Base.@lock PYJLMETHODS_LOCK PYJLMETHODS[num]
200+
f = Base.@lock PYJLMETHODS PYJLMETHODS[][num]
211201
x = PyJuliaValue_GetValue(o)
212202
return _pyjl_get_buffer_impl(o, buf, flags, x, f)::Cint
213203
catch exc
@@ -221,9 +211,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
221211
end
222212

223213
function _pyjl_release_buffer(xo::C.PyPtr, buf::Ptr{C.Py_buffer})
224-
Base.@lock PYJLBUFCACHE_LOCK begin
225-
delete!(PYJLBUFCACHE, UnsafePtr(buf).internal[!])
226-
end
214+
Base.@lock PYJLBUFCACHE delete!(PYJLBUFCACHE[], UnsafePtr(buf).internal[!])
227215
nothing
228216
end
229217

@@ -355,28 +343,26 @@ PyJuliaValue_IsNull(o) = Base.GC.@preserve o UnsafePtr{PyJuliaValueObject}(C.asp
355343

356344
PyJuliaValue_GetValue(o) = Base.GC.@preserve o begin
357345
idx = UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]
358-
Base.@lock PYJLVALUES_LOCK begin
359-
PYJLVALUES[idx]
360-
end
346+
Base.@lock PYJLVALUES PYJLVALUES[].values[idx]
361347
end
362348

363349
PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
364350
o = C.asptr(_o)
365351
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
366352
if idx == 0
367-
Base.@lock PYJLVALUES_LOCK begin
368-
if isempty(PYJLFREEVALUES)
369-
idx = PYJLVALUES_NEXT_IDX[]
370-
PYJLVALUES_NEXT_IDX[] += 1
353+
Base.@lock PYJLVALUES begin
354+
if isempty(PYJLVALUES[].free)
355+
idx = PYJLVALUES[].next[]
356+
PYJLVALUES[].next[] += 1
371357
else
372-
idx = pop!(PYJLFREEVALUES)
358+
idx = pop!(PYJLVALUES[].free)
373359
end
374-
PYJLVALUES[idx] = v
360+
PYJLVALUES[].values[idx] = v
375361
end
376362
UnsafePtr{PyJuliaValueObject}(o).value[] = idx
377363
else
378-
Base.@lock PYJLVALUES_LOCK begin
379-
PYJLVALUES[idx] = v
364+
Base.@lock PYJLVALUES begin
365+
PYJLVALUES[].values[idx] = v
380366
end
381367
end
382368
nothing

src/Utils/Utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,4 +308,30 @@ function Base.iterate(x::StaticString{UInt32,N}, i::Int = 1) where {N}
308308
end
309309
end
310310

311+
@static if !isdefined(Base, :Lockable)
312+
"""
313+
Compat for `Base.Lockable` (introduced in Julia 1.11)
314+
"""
315+
struct Lockable{T, L<:AbstractLock}
316+
value::T
317+
lock::L
318+
end
319+
320+
Lockable(value) = Lockable(value, ReentrantLock())
321+
322+
function Base.lock(f, l::Lockable)
323+
lock(l.lock) do
324+
f(l.value)
325+
end
326+
end
327+
328+
Base.lock(l::Lockable) = lock(l.lock)
329+
Base.trylock(l::Lockable) = trylock(l.lock)
330+
Base.unlock(l::Lockable) = unlock(l.lock)
331+
Base.islocked(l::Lockable) = islocked(l.lock)
332+
Base.getindex(l::Lockable) = (@assert islocked(l); l.value)
333+
else
334+
const Lockable = Base.Lockable
335+
end
336+
311337
end

0 commit comments

Comments
 (0)