Skip to content

Commit 8563cb0

Browse files
committed
Factor out peer access validation
also make peer access invalid when the device and peer device are the same
1 parent 7106845 commit 8563cb0

File tree

1 file changed

+46
-96
lines changed

1 file changed

+46
-96
lines changed

dpctl/_sycl_device.pyx

Lines changed: 46 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,28 @@ cdef void _init_helper(_SyclDevice device, DPCTLSyclDeviceRef DRef) except *:
217217
raise RuntimeError("Descriptor 'max_work_item_sizes3d' not available")
218218

219219

220+
cdef bint _check_peer_access(SyclDevice dev, SyclDevice peer) except *:
221+
"""
222+
Check peer access ahead of time to avoid errors from unified runtime or
223+
compiler implementation.
224+
"""
225+
cdef list _peer_access_backends = [
226+
_backend_type._CUDA,
227+
_backend_type._HIP,
228+
_backend_type._LEVEL_ZERO
229+
]
230+
cdef _backend_type BTy1 = DPCTLDevice_GetBackend(dev._device_ref)
231+
cdef _backend_type BTy2 = DPCTLDevice_GetBackend(peer.get_device_ref())
232+
if (
233+
BTy1 == BTy2 and
234+
BTy1 in _peer_access_backends and
235+
BTy2 in _peer_access_backends and
236+
dev != peer
237+
):
238+
return True
239+
return False
240+
241+
220242
@functools.lru_cache(maxsize=None)
221243
def _cached_filter_string(d : SyclDevice):
222244
"""
@@ -1819,13 +1841,8 @@ cdef class SyclDevice(_SyclDevice):
18191841
Raises:
18201842
TypeError:
18211843
If ``peer`` is not :class:`dpctl.SyclDevice`.
1822-
ValueError:
1823-
If the backend associated with this device or ``peer`` does not
1824-
support peer access.
18251844
"""
18261845
cdef SyclDevice p_dev
1827-
cdef _backend_type BTy1
1828-
cdef _backend_type BTy2
18291846

18301847
if not isinstance(peer, SyclDevice):
18311848
raise TypeError(
@@ -1834,29 +1851,13 @@ cdef class SyclDevice(_SyclDevice):
18341851
)
18351852
p_dev = <SyclDevice>peer
18361853

1837-
_peer_access_backends = [
1838-
_backend_type._CUDA,
1839-
_backend_type._HIP,
1840-
_backend_type._LEVEL_ZERO
1841-
]
1842-
BTy1 = DPCTLDevice_GetBackend(self._device_ref)
1843-
if BTy1 not in _peer_access_backends:
1844-
raise ValueError(
1845-
"Peer access not supported for this device backend "
1846-
f"{_backend_type_to_filter_string_part(BTy1)}"
1854+
if _check_peer_access(self, p_dev):
1855+
return DPCTLDevice_CanAccessPeer(
1856+
self._device_ref,
1857+
p_dev.get_device_ref(),
1858+
_peer_access._access_supported
18471859
)
1848-
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
1849-
if BTy2 not in _peer_access_backends:
1850-
raise ValueError(
1851-
"Peer access not supported for peer device backend "
1852-
f"{_backend_type_to_filter_string_part(BTy2)}"
1853-
)
1854-
1855-
return DPCTLDevice_CanAccessPeer(
1856-
self._device_ref,
1857-
p_dev.get_device_ref(),
1858-
_peer_access._access_supported
1859-
)
1860+
return False
18601861

18611862
def can_access_peer_atomics_supported(self, peer):
18621863
""" Returns ``True`` if this device (``self``) can concurrently access
@@ -1883,13 +1884,8 @@ cdef class SyclDevice(_SyclDevice):
18831884
Raises:
18841885
TypeError:
18851886
If ``peer`` is not :class:`dpctl.SyclDevice`.
1886-
ValueError:
1887-
If the backend associated with this device or ``peer`` does not
1888-
support peer access.
18891887
"""
18901888
cdef SyclDevice p_dev
1891-
cdef _backend_type BTy1
1892-
cdef _backend_type BTy2
18931889

18941890
if not isinstance(peer, SyclDevice):
18951891
raise TypeError(
@@ -1898,29 +1894,13 @@ cdef class SyclDevice(_SyclDevice):
18981894
)
18991895
p_dev = <SyclDevice>peer
19001896

1901-
_peer_access_backends = [
1902-
_backend_type._CUDA,
1903-
_backend_type._HIP,
1904-
_backend_type._LEVEL_ZERO
1905-
]
1906-
BTy1 = DPCTLDevice_GetBackend(self._device_ref)
1907-
if BTy1 not in _peer_access_backends:
1908-
raise ValueError(
1909-
"Peer access not supported for this device backend "
1910-
f"{_backend_type_to_filter_string_part(BTy1)}"
1897+
if _check_peer_access(self, p_dev):
1898+
return DPCTLDevice_CanAccessPeer(
1899+
self._device_ref,
1900+
p_dev.get_device_ref(),
1901+
_peer_access._atomics_supported
19111902
)
1912-
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
1913-
if BTy2 not in _peer_access_backends:
1914-
raise ValueError(
1915-
"Peer access not supported for peer device backend "
1916-
f"{_backend_type_to_filter_string_part(BTy2)}"
1917-
)
1918-
1919-
return DPCTLDevice_CanAccessPeer(
1920-
self._device_ref,
1921-
p_dev.get_device_ref(),
1922-
_peer_access._atomics_supported
1923-
)
1903+
return False
19241904

19251905
def enable_peer_access(self, peer):
19261906
""" Enables this device (``self``) to access USM device allocations
@@ -1944,8 +1924,6 @@ cdef class SyclDevice(_SyclDevice):
19441924
support peer access.
19451925
"""
19461926
cdef SyclDevice p_dev
1947-
cdef _backend_type BTy1
1948-
cdef _backend_type BTy2
19491927

19501928
if not isinstance(peer, SyclDevice):
19511929
raise TypeError(
@@ -1954,27 +1932,13 @@ cdef class SyclDevice(_SyclDevice):
19541932
)
19551933
p_dev = <SyclDevice>peer
19561934

1957-
_peer_access_backends = [
1958-
_backend_type._CUDA,
1959-
_backend_type._HIP,
1960-
_backend_type._LEVEL_ZERO
1961-
]
1962-
BTy1 = (
1963-
DPCTLDevice_GetBackend(self._device_ref)
1964-
)
1965-
if BTy1 not in _peer_access_backends:
1966-
raise ValueError(
1967-
"Peer access not supported for this device backend "
1968-
f"{_backend_type_to_filter_string_part(BTy1)}"
1935+
if _check_peer_access(self, p_dev):
1936+
DPCTLDevice_EnablePeerAccess(
1937+
self._device_ref,
1938+
p_dev.get_device_ref()
19691939
)
1970-
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
1971-
if BTy2 not in _peer_access_backends:
1972-
raise ValueError(
1973-
"Peer access not supported for peer device backend "
1974-
f"{_backend_type_to_filter_string_part(BTy2)}"
1975-
)
1976-
1977-
DPCTLDevice_EnablePeerAccess(self._device_ref, p_dev.get_device_ref())
1940+
else:
1941+
raise ValueError("Peer access cannot be enabled for these devices")
19781942
return
19791943

19801944
def disable_peer_access(self, peer):
@@ -1998,8 +1962,6 @@ cdef class SyclDevice(_SyclDevice):
19981962
support peer access.
19991963
"""
20001964
cdef SyclDevice p_dev
2001-
cdef _backend_type BTy1
2002-
cdef _backend_type BTy2
20031965

20041966
if not isinstance(peer, SyclDevice):
20051967
raise TypeError(
@@ -2008,25 +1970,13 @@ cdef class SyclDevice(_SyclDevice):
20081970
)
20091971
p_dev = <SyclDevice>peer
20101972

2011-
_peer_access_backends = [
2012-
_backend_type._CUDA,
2013-
_backend_type._HIP,
2014-
_backend_type._LEVEL_ZERO
2015-
]
2016-
BTy1 = DPCTLDevice_GetBackend(self._device_ref)
2017-
if BTy1 not in _peer_access_backends:
2018-
raise ValueError(
2019-
"Peer access not supported for this device backend "
2020-
f"{_backend_type_to_filter_string_part(BTy1)}"
1973+
if _check_peer_access(self, p_dev):
1974+
DPCTLDevice_DisablePeerAccess(
1975+
self._device_ref,
1976+
p_dev.get_device_ref()
20211977
)
2022-
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
2023-
if BTy2 not in _peer_access_backends:
2024-
raise ValueError(
2025-
"Peer access not supported for peer device backend "
2026-
f"{_backend_type_to_filter_string_part(BTy2)}"
2027-
)
2028-
2029-
DPCTLDevice_DisablePeerAccess(self._device_ref, p_dev.get_device_ref())
1978+
else:
1979+
raise ValueError("Peer access cannot be enabled for these devices")
20301980
return
20311981

20321982
@property

0 commit comments

Comments
 (0)