diff --git a/sycl/source/detail/sampler_impl.cpp b/sycl/source/detail/sampler_impl.cpp index 6ef03918f2862..243c8eb1fe9ec 100644 --- a/sycl/source/detail/sampler_impl.cpp +++ b/sycl/source/detail/sampler_impl.cpp @@ -16,28 +16,29 @@ namespace detail { sampler_impl::sampler_impl(coordinate_normalization_mode normalizationMode, addressing_mode addressingMode, filtering_mode filteringMode) - : m_CoordNormMode(normalizationMode), m_AddrMode(addressingMode), - m_FiltMode(filteringMode) {} + : MCoordNormMode(normalizationMode), MAddrMode(addressingMode), + MFiltMode(filteringMode) {} sampler_impl::sampler_impl(cl_sampler clSampler, const context &syclContext) { RT::PiSampler Sampler = pi::cast(clSampler); - m_contextToSampler[syclContext] = Sampler; + MContextToSampler[syclContext] = Sampler; const detail::plugin &Plugin = getSyclObjImpl(syclContext)->getPlugin(); Plugin.call(Sampler); Plugin.call( Sampler, PI_SAMPLER_INFO_NORMALIZED_COORDS, sizeof(pi_bool), - &m_CoordNormMode, nullptr); + &MCoordNormMode, nullptr); Plugin.call( Sampler, PI_SAMPLER_INFO_ADDRESSING_MODE, - sizeof(pi_sampler_addressing_mode), &m_AddrMode, nullptr); + sizeof(pi_sampler_addressing_mode), &MAddrMode, nullptr); Plugin.call(Sampler, PI_SAMPLER_INFO_FILTER_MODE, sizeof(pi_sampler_filter_mode), - &m_FiltMode, nullptr); + &MFiltMode, nullptr); } sampler_impl::~sampler_impl() { - for (auto &Iter : m_contextToSampler) { + std::lock_guard Lock(MMutex); + for (auto &Iter : MContextToSampler) { // TODO catch an exception and add it to the list of asynchronous exceptions const detail::plugin &Plugin = getSyclObjImpl(Iter.first)->getPlugin(); Plugin.call(Iter.second); @@ -45,16 +46,20 @@ sampler_impl::~sampler_impl() { } RT::PiSampler sampler_impl::getOrCreateSampler(const context &Context) { - if (m_contextToSampler[Context]) - return m_contextToSampler[Context]; + { + std::lock_guard Lock(MMutex); + auto It = MContextToSampler.find(Context); + if (It != MContextToSampler.end()) + return It->second; + } const pi_sampler_properties sprops[] = { PI_SAMPLER_INFO_NORMALIZED_COORDS, - static_cast(m_CoordNormMode), + static_cast(MCoordNormMode), PI_SAMPLER_INFO_ADDRESSING_MODE, - static_cast(m_AddrMode), + static_cast(MAddrMode), PI_SAMPLER_INFO_FILTER_MODE, - static_cast(m_FiltMode), + static_cast(MFiltMode), 0}; RT::PiResult errcode_ret = PI_SUCCESS; @@ -69,18 +74,19 @@ RT::PiSampler sampler_impl::getOrCreateSampler(const context &Context) { errcode_ret); Plugin.checkPiResult(errcode_ret); - m_contextToSampler[Context] = resultSampler; + std::lock_guard Lock(MMutex); + MContextToSampler[Context] = resultSampler; - return m_contextToSampler[Context]; + return resultSampler; } -addressing_mode sampler_impl::get_addressing_mode() const { return m_AddrMode; } +addressing_mode sampler_impl::get_addressing_mode() const { return MAddrMode; } -filtering_mode sampler_impl::get_filtering_mode() const { return m_FiltMode; } +filtering_mode sampler_impl::get_filtering_mode() const { return MFiltMode; } coordinate_normalization_mode sampler_impl::get_coordinate_normalization_mode() const { - return m_CoordNormMode; + return MCoordNormMode; } } // namespace detail diff --git a/sycl/source/detail/sampler_impl.hpp b/sycl/source/detail/sampler_impl.hpp index 271942d39ad0b..750fe6774ded2 100644 --- a/sycl/source/detail/sampler_impl.hpp +++ b/sycl/source/detail/sampler_impl.hpp @@ -23,14 +23,6 @@ enum class coordinate_normalization_mode : unsigned int; namespace detail { class __SYCL_EXPORT sampler_impl { -public: - std::unordered_map m_contextToSampler; - -private: - coordinate_normalization_mode m_CoordNormMode; - addressing_mode m_AddrMode; - filtering_mode m_FiltMode; - public: sampler_impl(coordinate_normalization_mode normalizationMode, addressing_mode addressingMode, filtering_mode filteringMode); @@ -46,6 +38,16 @@ class __SYCL_EXPORT sampler_impl { RT::PiSampler getOrCreateSampler(const context &Context); ~sampler_impl(); + +private: + /// Protects all the fields that can be changed by class' methods. + mutex_class MMutex; + + std::unordered_map MContextToSampler; + + coordinate_normalization_mode MCoordNormMode; + addressing_mode MAddrMode; + filtering_mode MFiltMode; }; } // namespace detail