Skip to content

Commit b6a5997

Browse files
committed
address comments
1 parent 293d48b commit b6a5997

File tree

5 files changed

+59
-68
lines changed

5 files changed

+59
-68
lines changed

sycl/tools/sycl-prof/collector.cpp

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,30 @@
1818
#include <thread>
1919
#include <unistd.h>
2020

21-
unsigned long process_id() { return static_cast<unsigned long>(getpid()); }
2221

2322
namespace chrono = std::chrono;
2423

2524
Writer *GWriter = nullptr;
2625

26+
struct Measurements {
27+
size_t TID;
28+
size_t PID;
29+
size_t TimeStamp;
30+
};
31+
32+
unsigned long process_id() { return static_cast<unsigned long>(getpid()); }
33+
34+
static Measurements measure() {
35+
size_t TID = std::hash<std::thread::id>{}(std::this_thread::get_id());
36+
size_t PID = process_id();
37+
auto Now = chrono::high_resolution_clock::now();
38+
size_t TS = chrono::time_point_cast<chrono::nanoseconds>(Now)
39+
.time_since_epoch()
40+
.count();
41+
42+
return Measurements{TID, PID, TS};
43+
}
44+
2745
XPTI_CALLBACK_API void piBeginEndCallback(uint16_t TraceType,
2846
xpti::trace_event_data_t *,
2947
xpti::trace_event_data_t *,
@@ -51,29 +69,21 @@ XPTI_CALLBACK_API void xptiTraceInit(unsigned int /*major_version*/,
5169

5270
if (std::string_view(StreamName) == "sycl.pi") {
5371
uint8_t StreamID = xptiRegisterStream(StreamName);
54-
xptiRegisterCallback(StreamID,
55-
(uint16_t)xpti::trace_point_type_t::function_begin,
72+
xptiRegisterCallback(StreamID, xpti::trace_function_begin,
5673
piBeginEndCallback);
57-
xptiRegisterCallback(StreamID,
58-
(uint16_t)xpti::trace_point_type_t::function_end,
74+
xptiRegisterCallback(StreamID, xpti::trace_function_end,
5975
piBeginEndCallback);
6076
} else if (std::string_view(StreamName) == "sycl") {
6177
uint8_t StreamID = xptiRegisterStream(StreamName);
62-
xptiRegisterCallback(StreamID,
63-
(uint16_t)xpti::trace_point_type_t::task_begin,
78+
xptiRegisterCallback(StreamID, xpti::trace_task_begin,
6479
taskBeginEndCallback);
65-
xptiRegisterCallback(StreamID, (uint16_t)xpti::trace_point_type_t::task_end,
66-
taskBeginEndCallback);
67-
xptiRegisterCallback(StreamID,
68-
(uint16_t)xpti::trace_point_type_t::wait_begin,
69-
waitBeginEndCallback);
70-
xptiRegisterCallback(StreamID, (uint16_t)xpti::trace_point_type_t::wait_end,
80+
xptiRegisterCallback(StreamID, xpti::trace_task_end, taskBeginEndCallback);
81+
xptiRegisterCallback(StreamID, xpti::trace_wait_begin,
7182
waitBeginEndCallback);
72-
xptiRegisterCallback(StreamID,
73-
(uint16_t)xpti::trace_point_type_t::barrier_begin,
83+
xptiRegisterCallback(StreamID, xpti::trace_wait_end, waitBeginEndCallback);
84+
xptiRegisterCallback(StreamID, xpti::trace_barrier_begin,
7485
waitBeginEndCallback);
75-
xptiRegisterCallback(StreamID,
76-
(uint16_t)xpti::trace_point_type_t::barrier_end,
86+
xptiRegisterCallback(StreamID, xpti::trace_barrier_end,
7787
waitBeginEndCallback);
7888
}
7989
}
@@ -85,13 +95,8 @@ XPTI_CALLBACK_API void piBeginEndCallback(uint16_t TraceType,
8595
xpti::trace_event_data_t *,
8696
uint64_t /*Instance*/,
8797
const void *UserData) {
88-
unsigned long TID = std::hash<std::thread::id>{}(std::this_thread::get_id());
89-
unsigned long PID = process_id();
90-
auto Now = chrono::high_resolution_clock::now();
91-
auto TS = chrono::time_point_cast<chrono::nanoseconds>(Now)
92-
.time_since_epoch()
93-
.count();
94-
if (TraceType == (uint16_t)xpti::trace_point_type_t::function_begin) {
98+
auto [TID, PID, TS] = measure();
99+
if (TraceType == xpti::trace_function_begin) {
95100
GWriter->writeBegin(static_cast<const char *>(UserData), "Plugin", PID, TID,
96101
TS);
97102
} else {
@@ -105,9 +110,6 @@ XPTI_CALLBACK_API void taskBeginEndCallback(uint16_t TraceType,
105110
xpti::trace_event_data_t *Event,
106111
uint64_t /*Instance*/,
107112
const void *) {
108-
unsigned long TID = std::hash<std::thread::id>{}(std::this_thread::get_id());
109-
unsigned long PID = process_id();
110-
111113
std::string_view Name = "unknown";
112114

113115
xpti::metadata_t *Metadata = xptiQueryMetadata(Event);
@@ -118,12 +120,8 @@ XPTI_CALLBACK_API void taskBeginEndCallback(uint16_t TraceType,
118120
}
119121
}
120122

121-
auto Now = chrono::high_resolution_clock::now();
122-
auto TS = chrono::time_point_cast<chrono::nanoseconds>(Now)
123-
.time_since_epoch()
124-
.count();
125-
126-
if (TraceType == (uint16_t)xpti::trace_point_type_t::task_begin) {
123+
auto [TID, PID, TS] = measure();
124+
if (TraceType == xpti::trace_task_begin) {
127125
GWriter->writeBegin(Name, "SYCL", PID, TID, TS);
128126
} else {
129127
GWriter->writeEnd(Name, "SYCL", PID, TID, TS);
@@ -135,14 +133,9 @@ XPTI_CALLBACK_API void waitBeginEndCallback(uint16_t TraceType,
135133
xpti::trace_event_data_t *,
136134
uint64_t /*Instance*/,
137135
const void *UserData) {
138-
unsigned long TID = std::hash<std::thread::id>{}(std::this_thread::get_id());
139-
unsigned long PID = process_id();
140-
auto Now = chrono::high_resolution_clock::now();
141-
auto TS = chrono::time_point_cast<chrono::nanoseconds>(Now)
142-
.time_since_epoch()
143-
.count();
144-
if (TraceType == (uint16_t)xpti::trace_point_type_t::wait_begin ||
145-
TraceType == (uint16_t)xpti::trace_point_type_t::barrier_begin) {
136+
auto [TID, PID, TS] = measure();
137+
if (TraceType == xpti::trace_wait_begin ||
138+
TraceType == xpti::trace_barrier_begin) {
146139
GWriter->writeBegin(static_cast<const char *>(UserData), "SYCL", PID, TID,
147140
TS);
148141
} else {

sycl/tools/sycl-prof/writer.hpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ class Writer {
1717
virtual void init() = 0;
1818
virtual void finalize() = 0;
1919
virtual void writeBegin(std::string_view Name, std::string_view Category,
20-
unsigned long PID, unsigned long TID,
21-
unsigned long TimeStamp) = 0;
20+
size_t PID, size_t TID, size_t TimeStamp) = 0;
2221
virtual void writeEnd(std::string_view Name, std::string_view Category,
23-
unsigned long PID, unsigned long TID,
24-
unsigned long TimeStamp) = 0;
22+
size_t PID, size_t TID, size_t TimeStamp) = 0;
2523
virtual ~Writer() = default;
2624
};
2725

@@ -36,9 +34,8 @@ class JSONWriter : public Writer {
3634
MOutFile << " \"traceEvents\": [\n";
3735
}
3836

39-
void writeBegin(std::string_view Name, std::string_view Category,
40-
unsigned long PID, unsigned long TID,
41-
unsigned long TimeStamp) override {
37+
void writeBegin(std::string_view Name, std::string_view Category, size_t PID,
38+
size_t TID, size_t TimeStamp) override {
4239
std::lock_guard _{MWriteMutex};
4340

4441
if (!MOutFile.is_open())
@@ -53,9 +50,8 @@ class JSONWriter : public Writer {
5350
MOutFile << std::endl;
5451
}
5552

56-
void writeEnd(std::string_view Name, std::string_view Category,
57-
unsigned long PID, unsigned long TID,
58-
unsigned long TimeStamp) override {
53+
void writeEnd(std::string_view Name, std::string_view Category, size_t PID,
54+
size_t TID, size_t TimeStamp) override {
5955
std::lock_guard _{MWriteMutex};
6056

6157
if (!MOutFile.is_open())

sycl/tools/sycl-sanitize/collector.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,10 @@ XPTI_CALLBACK_API void xptiTraceInit(unsigned int /*major_version*/,
143143
if (std::string_view(StreamName) == "sycl.pi.debug") {
144144
GS = new GlobalState;
145145
uint8_t StreamID = xptiRegisterStream(StreamName);
146-
xptiRegisterCallback(
147-
StreamID, (uint16_t)xpti::trace_point_type_t::function_with_args_begin,
148-
tpCallback);
149-
xptiRegisterCallback(
150-
StreamID, (uint16_t)xpti::trace_point_type_t::function_with_args_end,
151-
tpCallback);
146+
xptiRegisterCallback(StreamID, xpti::trace_function_with_args_begin,
147+
tpCallback);
148+
xptiRegisterCallback(StreamID, xpti::trace_function_with_args_end,
149+
tpCallback);
152150

153151
GS->ArgHandlerPostCall.set_piextUSMHostAlloc(handleUSMHostAlloc);
154152
GS->ArgHandlerPostCall.set_piextUSMDeviceAlloc(handleUSMDeviceAlloc);
@@ -199,16 +197,15 @@ XPTI_CALLBACK_API void tpCallback(uint16_t TraceType,
199197
GS->LastTracepoint.Line = 0;
200198
}
201199

202-
auto Type = static_cast<xpti::trace_point_type_t>(TraceType);
203200
// Lock while we capture information
204201
std::lock_guard<std::mutex> Lock(GS->IOMutex);
205202

206203
const auto *Data = static_cast<const xpti::function_with_args_t *>(UserData);
207204
const auto *Plugin = static_cast<pi_plugin *>(Data->user_data);
208-
if (Type == xpti::trace_point_type_t::function_with_args_begin) {
205+
if (TraceType == xpti::trace_function_with_args_begin) {
209206
GS->ArgHandlerPreCall.handle(Data->function_id, *Plugin, std::nullopt,
210207
Data->args_data);
211-
} else if (Type == xpti::trace_point_type_t::function_with_args_end) {
208+
} else if (TraceType == xpti::trace_function_with_args_end) {
212209
const pi_result Result = *static_cast<pi_result *>(Data->ret_data);
213210
GS->ArgHandlerPostCall.handle(Data->function_id, *Plugin, Result,
214211
Data->args_data);

sycl/tools/sycl-trace/pi_trace_collector.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,10 @@ XPTI_CALLBACK_API void xptiTraceInit(unsigned int /*major_version*/,
4343
const char *stream_name) {
4444
if (std::string_view(stream_name) == "sycl.pi.debug") {
4545
GStreamID = xptiRegisterStream(stream_name);
46-
xptiRegisterCallback(
47-
GStreamID, (uint16_t)xpti::trace_point_type_t::function_with_args_begin,
48-
tpCallback);
49-
xptiRegisterCallback(
50-
GStreamID, (uint16_t)xpti::trace_point_type_t::function_with_args_end,
51-
tpCallback);
46+
xptiRegisterCallback(GStreamID, xpti::trace_function_with_args_begin,
47+
tpCallback);
48+
xptiRegisterCallback(GStreamID, xpti::trace_function_with_args_end,
49+
tpCallback);
5250

5351
#define _PI_API(api) \
5452
ArgHandler.set##_##api( \
@@ -71,8 +69,7 @@ XPTI_CALLBACK_API void tpCallback(uint16_t TraceType,
7169
xpti::trace_event_data_t * /*Parent*/,
7270
xpti::trace_event_data_t * /*Event*/,
7371
uint64_t /*Instance*/, const void *UserData) {
74-
auto Type = static_cast<xpti::trace_point_type_t>(TraceType);
75-
if (Type == xpti::trace_point_type_t::function_with_args_end) {
72+
if (TraceType == xpti::trace_function_with_args_end) {
7673
// Lock while we print information
7774
std::lock_guard<std::mutex> Lock(GIOMutex);
7875

xpti/include/xpti/xpti_data_types.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,14 @@ constexpr uint16_t trace_edge_create =
668668
static_cast<uint16_t>(xpti::trace_point_type_t::edge_create);
669669
constexpr uint16_t trace_signal =
670670
static_cast<uint16_t>(xpti::trace_point_type_t::signal);
671+
constexpr uint16_t trace_function_begin =
672+
static_cast<uint16_t>(xpti::trace_point_type_t::function_begin);
673+
constexpr uint16_t trace_function_end =
674+
static_cast<uint16_t>(xpti::trace_point_type_t::function_end);
675+
constexpr uint16_t trace_function_with_args_begin =
676+
static_cast<uint16_t>(xpti::trace_point_type_t::function_with_args_begin);
677+
constexpr uint16_t trace_function_with_args_end =
678+
static_cast<uint16_t>(xpti::trace_point_type_t::function_with_args_end);
671679
constexpr uint16_t trace_offload_alloc_construct =
672680
static_cast<uint16_t>(xpti::trace_point_type_t::offload_alloc_construct);
673681
constexpr uint16_t trace_offload_alloc_associate =

0 commit comments

Comments
 (0)