Skip to content

Commit 8cfdffe

Browse files
authored
fix: Boundary check for MessageQueueShm head index (#405)
1 parent 87f6f2a commit 8cfdffe

File tree

9 files changed

+353
-241
lines changed

9 files changed

+353
-241
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ set(
237237
src/response_sender.h
238238
src/pb_stub.h
239239
src/pb_stub.cc
240+
src/pb_stub_log.h
241+
src/pb_stub_log.cc
240242
src/pb_response_iterator.h
241243
src/pb_response_iterator.cc
242244
src/pb_cancel.cc

src/message_queue.h

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -32,14 +32,19 @@
3232
#include <boost/thread/thread_time.hpp>
3333
#include <cstddef>
3434

35+
#include "pb_utils.h"
3536
#include "shm_manager.h"
37+
#ifdef TRITON_PB_STUB
38+
#include "pb_stub_log.h"
39+
#endif
3640

3741
namespace triton { namespace backend { namespace python {
3842
namespace bi = boost::interprocess;
3943

4044
/// Struct holding the representation of a message queue inside the shared
4145
/// memory.
42-
/// \param size Total size of the message queue.
46+
/// \param size Total size of the message queue. Considered invalid after
47+
/// MessageQueue::LoadFromSharedMemory. Check DLIS-8378 for additional details.
4348
/// \param mutex Handle of the mutex variable protecting index.
4449
/// \param index Used element index.
4550
/// \param sem_empty Semaphore object counting the number of empty buffer slots.
@@ -110,7 +115,22 @@ class MessageQueue {
110115

111116
{
112117
bi::scoped_lock<bi::interprocess_mutex> lock{*MutexMutable()};
113-
Buffer()[Head()] = message;
118+
int head_idx = Head();
119+
// Additional check to avoid out of bounds read/write. Check DLIS-8378 for
120+
// additional details.
121+
if (head_idx < 0 || static_cast<uint32_t>(head_idx) >= Size()) {
122+
std::string error_msg =
123+
"internal error: message queue head index out of bounds. Expects "
124+
"positive integer less than the size of message queue " +
125+
std::to_string(Size()) + " but got " + std::to_string(head_idx);
126+
#ifdef TRITON_PB_STUB
127+
LOG_ERROR << error_msg;
128+
#else
129+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str());
130+
#endif
131+
return;
132+
}
133+
Buffer()[head_idx] = message;
114134
HeadIncrement();
115135
}
116136
SemFullMutable()->post();
@@ -145,7 +165,22 @@ class MessageQueue {
145165
}
146166
success = true;
147167

148-
Buffer()[Head()] = message;
168+
int head_idx = Head();
169+
// Additional check to avoid out of bounds read/write. Check DLIS-8378 for
170+
// additional details.
171+
if (head_idx < 0 || static_cast<uint32_t>(head_idx) >= Size()) {
172+
std::string error_msg =
173+
"internal error: message queue head index out of bounds. Expects "
174+
"positive integer less than the size of message queue " +
175+
std::to_string(Size()) + " but got " + std::to_string(head_idx);
176+
#ifdef TRITON_PB_STUB
177+
LOG_ERROR << error_msg;
178+
#else
179+
LOG_MESSAGE(TRITONSERVER_LOG_ERROR, error_msg.c_str());
180+
#endif
181+
return;
182+
}
183+
Buffer()[head_idx] = message;
149184
HeadIncrement();
150185
}
151186
SemFullMutable()->post();
@@ -244,7 +279,7 @@ class MessageQueue {
244279
}
245280

246281
private:
247-
std::size_t& Size() { return mq_shm_ptr_->size; }
282+
uint32_t Size() { return size_; }
248283
const bi::interprocess_mutex& Mutex() { return mq_shm_ptr_->mutex; }
249284
bi::interprocess_mutex* MutexMutable() { return &(mq_shm_ptr_->mutex); }
250285
int& Head() { return mq_shm_ptr_->head; }
@@ -273,6 +308,7 @@ class MessageQueue {
273308
MessageQueueShm* mq_shm_ptr_;
274309
T* mq_buffer_shm_ptr_;
275310
bi::managed_external_buffer::handle_t mq_handle_;
311+
uint32_t size_;
276312

277313
/// Create/load a Message queue.
278314
/// \param mq_shm Message queue representation in shared memory.
@@ -284,6 +320,7 @@ class MessageQueue {
284320
mq_buffer_shm_ptr_ = mq_buffer_shm_.data_.get();
285321
mq_shm_ptr_ = mq_shm_.data_.get();
286322
mq_handle_ = mq_shm_.handle_;
323+
size_ = mq_shm_ptr_->size;
287324
}
288325
};
289326
}}} // namespace triton::backend::python

src/pb_bls_cancel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "pb_bls_cancel.h"
2828

2929
#include "pb_stub.h"
30+
#include "pb_stub_log.h"
3031

3132
namespace triton { namespace backend { namespace python {
3233

src/pb_cancel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -27,6 +27,7 @@
2727
#include "pb_cancel.h"
2828

2929
#include "pb_stub.h"
30+
#include "pb_stub_log.h"
3031

3132
namespace triton { namespace backend { namespace python {
3233

src/pb_stub.cc

Lines changed: 1 addition & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "pb_preferred_memory.h"
5050
#include "pb_response_iterator.h"
5151
#include "pb_string.h"
52+
#include "pb_stub_log.h"
5253
#include "pb_utils.h"
5354
#include "response_sender.h"
5455
#include "scoped_defer.h"
@@ -1569,138 +1570,6 @@ Stub::ProcessBLSResponseDecoupled(std::unique_ptr<IPCMessage>& ipc_message)
15691570
}
15701571
}
15711572

1572-
std::unique_ptr<Logger> Logger::log_instance_;
1573-
1574-
std::unique_ptr<Logger>&
1575-
Logger::GetOrCreateInstance()
1576-
{
1577-
if (Logger::log_instance_.get() == nullptr) {
1578-
Logger::log_instance_ = std::make_unique<Logger>();
1579-
}
1580-
1581-
return Logger::log_instance_;
1582-
}
1583-
1584-
// Bound function, called from the python client
1585-
void
1586-
Logger::Log(const std::string& message, LogLevel level)
1587-
{
1588-
std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();
1589-
py::object frame = py::module_::import("inspect").attr("currentframe");
1590-
py::object caller_frame = frame();
1591-
py::object info = py::module_::import("inspect").attr("getframeinfo");
1592-
py::object caller_info = info(caller_frame);
1593-
py::object filename_python = caller_info.attr("filename");
1594-
std::string filename = filename_python.cast<std::string>();
1595-
py::object lineno = caller_info.attr("lineno");
1596-
uint32_t line = lineno.cast<uint32_t>();
1597-
1598-
if (!stub->StubToParentServiceActive()) {
1599-
Logger::GetOrCreateInstance()->Log(filename, line, level, message);
1600-
} else {
1601-
std::unique_ptr<PbLog> log_msg(new PbLog(filename, line, message, level));
1602-
stub->EnqueueLogRequest(log_msg);
1603-
}
1604-
}
1605-
1606-
// Called internally (.e.g. LOG_ERROR << "Error"; )
1607-
void
1608-
Logger::Log(
1609-
const std::string& filename, uint32_t lineno, LogLevel level,
1610-
const std::string& message)
1611-
{
1612-
// If the log monitor service is not active yet, format
1613-
// and pass messages to cerr
1614-
if (!BackendLoggingActive()) {
1615-
std::string path(filename);
1616-
size_t pos = path.rfind(std::filesystem::path::preferred_separator);
1617-
if (pos != std::string::npos) {
1618-
path = path.substr(pos + 1, std::string::npos);
1619-
}
1620-
#ifdef _WIN32
1621-
std::stringstream ss;
1622-
SYSTEMTIME system_time;
1623-
GetSystemTime(&system_time);
1624-
ss << LeadingLogChar(level) << std::setfill('0') << std::setw(2)
1625-
<< system_time.wMonth << std::setw(2) << system_time.wDay << ' '
1626-
<< std::setw(2) << system_time.wHour << ':' << std::setw(2)
1627-
<< system_time.wMinute << ':' << std::setw(2) << system_time.wSecond
1628-
<< '.' << std::setw(6) << system_time.wMilliseconds * 1000 << ' '
1629-
<< static_cast<uint32_t>(GetCurrentProcessId()) << ' ' << path << ':'
1630-
<< lineno << "] ";
1631-
#else
1632-
std::stringstream ss;
1633-
struct timeval tv;
1634-
gettimeofday(&tv, NULL);
1635-
struct tm tm_time;
1636-
gmtime_r(((time_t*)&(tv.tv_sec)), &tm_time);
1637-
ss << LeadingLogChar(level) << std::setfill('0') << std::setw(2)
1638-
<< (tm_time.tm_mon + 1) << std::setw(2) << tm_time.tm_mday << " "
1639-
<< std::setw(2) << tm_time.tm_hour << ':' << std::setw(2)
1640-
<< tm_time.tm_min << ':' << std::setw(2) << tm_time.tm_sec << "."
1641-
<< std::setw(6) << tv.tv_usec << ' ' << static_cast<uint32_t>(getpid())
1642-
<< ' ' << path << ':' << lineno << "] ";
1643-
std::cerr << ss.str() << " " << message << std::endl;
1644-
#endif
1645-
} else {
1646-
// Ensure we do not create a stub instance before it has initialized
1647-
std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();
1648-
std::unique_ptr<PbLog> log_msg(new PbLog(filename, lineno, message, level));
1649-
stub->EnqueueLogRequest(log_msg);
1650-
}
1651-
}
1652-
1653-
void
1654-
Logger::LogInfo(const std::string& message)
1655-
{
1656-
Logger::Log(message, LogLevel::kInfo);
1657-
}
1658-
1659-
void
1660-
Logger::LogWarn(const std::string& message)
1661-
{
1662-
Logger::Log(message, LogLevel::kWarning);
1663-
}
1664-
1665-
void
1666-
Logger::LogError(const std::string& message)
1667-
{
1668-
Logger::Log(message, LogLevel::kError);
1669-
}
1670-
1671-
void
1672-
Logger::LogVerbose(const std::string& message)
1673-
{
1674-
Logger::Log(message, LogLevel::kVerbose);
1675-
}
1676-
1677-
const std::string
1678-
Logger::LeadingLogChar(const LogLevel& level)
1679-
{
1680-
switch (level) {
1681-
case LogLevel::kWarning:
1682-
return "W";
1683-
case LogLevel::kError:
1684-
return "E";
1685-
case LogLevel::kInfo:
1686-
case LogLevel::kVerbose:
1687-
default:
1688-
return "I";
1689-
}
1690-
}
1691-
1692-
void
1693-
Logger::SetBackendLoggingActive(bool status)
1694-
{
1695-
backend_logging_active_ = status;
1696-
}
1697-
1698-
bool
1699-
Logger::BackendLoggingActive()
1700-
{
1701-
return backend_logging_active_;
1702-
}
1703-
17041573
PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
17051574
{
17061575
py::class_<PbError, std::shared_ptr<PbError>> triton_error(

0 commit comments

Comments
 (0)