Skip to content

Commit 2eca398

Browse files
committed
fix: lock various methods
1 parent c84c059 commit 2eca398

File tree

4 files changed

+48
-22
lines changed

4 files changed

+48
-22
lines changed

src/Convert/pyconvert.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function pyconvert_add_rule(
7373
get!(Vector{PyConvertRule}, PYCONVERT_RULES, pytypename),
7474
PyConvertRule(type, func, priority),
7575
)
76-
empty!.(values(PYCONVERT_RULES_CACHE))
76+
Base.@lock PYCONVERT_RULES_CACHE_LOCK empty!.(values(PYCONVERT_RULES_CACHE))
7777
return
7878
end
7979

@@ -262,9 +262,10 @@ function _pyconvert_get_rules(pytype::Py)
262262
end
263263

264264
const PYCONVERT_PREFERRED_TYPE = Dict{Py,Type}()
265+
const PYCONVERT_PREFERRED_TYPE_LOCK = Threads.SpinLock()
265266

266267
pyconvert_preferred_type(pytype::Py) =
267-
get!(PYCONVERT_PREFERRED_TYPE, pytype) do
268+
Base.@lock PYCONVERT_PREFERRED_TYPE_LOCK get!(PYCONVERT_PREFERRED_TYPE, pytype) do
268269
if pyissubclass(pytype, pybuiltins.int)
269270
Union{Int,BigInt}
270271
else
@@ -308,9 +309,10 @@ end
308309
pyconvert_fix(::Type{T}, func) where {T} = x -> func(T, x)
309310

310311
const PYCONVERT_RULES_CACHE = Dict{Type,Dict{C.PyPtr,Vector{Function}}}()
312+
const PYCONVERT_RULES_CACHE_LOCK = Threads.SpinLock()
311313

312314
@generated pyconvert_rules_cache(::Type{T}) where {T} =
313-
get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE, T)
315+
Base.@lock PYCONVERT_RULES_CACHE_LOCK get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE, T)
314316

315317
function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
316318
if T isa Union
@@ -352,7 +354,7 @@ function pytryconvert(::Type{T}, x_) where {T}
352354
# TODO: we should hold weak references and clear the cache if types get deleted
353355
tptr = C.Py_Type(x)
354356
trules = pyconvert_rules_cache(T)
355-
rules = get!(trules, tptr) do
357+
rules = Base.@lock PYCONVERT_RULES_CACHE_LOCK get!(trules, tptr) do
356358
t = pynew(incref(tptr))
357359
ans = pyconvert_get_rules(T, t)::Vector{Function}
358360
pydel!(t)

src/Core/Py.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ decref(x::Py) = Base.GC.@preserve x (decref(getptr(x)); x)
5757
Base.unsafe_convert(::Type{C.PyPtr}, x::Py) = getptr(x)
5858

5959
const PYNULL_CACHE = Py[]
60+
const PYNULL_CACHE_LOCK = Threads.SpinLock()
6061

6162
"""
6263
pynew([ptr])
@@ -69,12 +70,13 @@ points at, i.e. the new `Py` object owns a reference.
6970
Note that NULL Python objects are not safe in the sense that most API functions will probably
7071
crash your Julia session if you pass a NULL argument.
7172
"""
72-
pynew() =
73+
pynew() = Base.@lock PYNULL_CACHE_LOCK begin
7374
if isempty(PYNULL_CACHE)
7475
Py(Val(:new), C.PyNULL)
7576
else
7677
pop!(PYNULL_CACHE)
7778
end
79+
end
7880

7981
const PyNULL = pynew()
8082

@@ -119,7 +121,7 @@ function pydel!(x::Py)
119121
C.Py_DecRef(ptr)
120122
setptr!(x, C.PyNULL)
121123
end
122-
push!(PYNULL_CACHE, x)
124+
Base.@lock PYNULL_CACHE_LOCK push!(PYNULL_CACHE, x)
123125
return
124126
end
125127

src/JlWrap/C.jl

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ const PyJuliaBase_Type = Ref(C.PyNULL)
1919
const PYJLVALUES = []
2020
# unused indices in PYJLVALUES
2121
const PYJLFREEVALUES = Int[]
22+
# Thread safety for PYJLVALUES and PYJLFREEVALUES
23+
const PYJLVALUES_LOCK = Threads.SpinLock()
2224

2325
function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr)
2426
o = ccall(UnsafePtr{C.PyTypeObject}(t).alloc[!], C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
@@ -31,20 +33,25 @@ end
3133
function _pyjl_dealloc(o::C.PyPtr)
3234
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
3335
if idx != 0
34-
PYJLVALUES[idx] = nothing
35-
push!(PYJLFREEVALUES, idx)
36+
Base.@lock PYJLVALUES_LOCK begin
37+
PYJLVALUES[idx] = nothing
38+
push!(PYJLFREEVALUES, idx)
39+
end
3640
end
3741
UnsafePtr{PyJuliaValueObject}(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o)
3842
ccall(UnsafePtr{C.PyTypeObject}(C.Py_Type(o)).free[!], Cvoid, (C.PyPtr,), o)
3943
nothing
4044
end
4145

4246
const PYJLMETHODS = Vector{Any}()
47+
const PYJLMETHODS_LOCK = Threads.SpinLock()
4348

4449
function PyJulia_MethodNum(f)
4550
@nospecialize f
46-
push!(PYJLMETHODS, f)
47-
return length(PYJLMETHODS)
51+
Base.@lock PYJLMETHODS_LOCK begin
52+
push!(PYJLMETHODS, f)
53+
return length(PYJLMETHODS)
54+
end
4855
end
4956

5057
function _pyjl_isnull(o::C.PyPtr, ::C.PyPtr)
@@ -58,12 +65,13 @@ function _pyjl_callmethod(o::C.PyPtr, args::C.PyPtr)
5865
@assert nargs > 0
5966
num = C.PyLong_AsLongLong(C.PyTuple_GetItem(args, 0))
6067
num == -1 && return C.PyNULL
61-
f = PYJLMETHODS[num]
68+
f = Base.@lock PYJLMETHODS_LOCK PYJLMETHODS[num]
6269
# this form gets defined in jlwrap/base.jl
6370
return _pyjl_callmethod(f, o, args, nargs)::C.PyPtr
6471
end
6572

6673
const PYJLBUFCACHE = Dict{Ptr{Cvoid},Any}()
74+
const PYJLBUFCACHE_LOCK = Threads.SpinLock()
6775

6876
@kwdef struct PyBufferInfo{N}
6977
# data
@@ -177,7 +185,9 @@ function _pyjl_get_buffer_impl(
177185

178186
# internal
179187
cptr = Base.pointer_from_objref(c)
180-
PYJLBUFCACHE[cptr] = c
188+
Base.@lock PYJLBUFCACHE_LOCK begin
189+
PYJLBUFCACHE[cptr] = c
190+
end
181191
b.internal[] = cptr
182192

183193
# obj
@@ -195,7 +205,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
195205
C.Py_DecRef(num_)
196206
num == -1 && return Cint(-1)
197207
try
198-
f = PYJLMETHODS[num]
208+
f = Base.@lock PYJLMETHODS_LOCK PYJLMETHODS[num]
199209
x = PyJuliaValue_GetValue(o)
200210
return _pyjl_get_buffer_impl(o, buf, flags, x, f)::Cint
201211
catch exc
@@ -209,7 +219,9 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
209219
end
210220

211221
function _pyjl_release_buffer(xo::C.PyPtr, buf::Ptr{C.Py_buffer})
212-
delete!(PYJLBUFCACHE, UnsafePtr(buf).internal[!])
222+
Base.@lock PYJLBUFCACHE_LOCK begin
223+
delete!(PYJLBUFCACHE, UnsafePtr(buf).internal[!])
224+
end
213225
nothing
214226
end
215227

@@ -339,22 +351,31 @@ end
339351

340352
PyJuliaValue_IsNull(o) = Base.GC.@preserve o UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[] == 0
341353

342-
PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]]
354+
PyJuliaValue_GetValue(o) = Base.GC.@preserve o begin
355+
idx = UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]
356+
Base.@lock PYJLVALUES_LOCK begin
357+
PYJLVALUES[idx]
358+
end
359+
end
343360

344361
PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
345362
o = C.asptr(_o)
346363
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
347364
if idx == 0
348-
if isempty(PYJLFREEVALUES)
349-
push!(PYJLVALUES, v)
350-
idx = length(PYJLVALUES)
351-
else
352-
idx = pop!(PYJLFREEVALUES)
353-
PYJLVALUES[idx] = v
365+
Base.@lock PYJLVALUES_LOCK begin
366+
if isempty(PYJLFREEVALUES)
367+
push!(PYJLVALUES, v)
368+
idx = length(PYJLVALUES)
369+
else
370+
idx = pop!(PYJLFREEVALUES)
371+
PYJLVALUES[idx] = v
372+
end
354373
end
355374
UnsafePtr{PyJuliaValueObject}(o).value[] = idx
356375
else
357-
PYJLVALUES[idx] = v
376+
Base.@lock PYJLVALUES_LOCK begin
377+
PYJLVALUES[idx] = v
378+
end
358379
end
359380
nothing
360381
end

src/JlWrap/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ function Cjl._pyjl_callmethod(f, self_::C.PyPtr, args_::C.PyPtr, nargs::C.Py_ssi
8484
pybuiltins.NotImplementedError,
8585
"__jl_callmethod not implemented for this many arguments",
8686
)
87+
return C.PyNULL
8788
end
8889
return getptr(incref(ans))
8990
catch exc

0 commit comments

Comments
 (0)