Skip to content

Commit 874f624

Browse files
authored
Merge pull request #5 from pytorch-labs/add_log
Add log.h
2 parents a875876 + a560ee7 commit 874f624

File tree

4 files changed

+373
-27
lines changed

4 files changed

+373
-27
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ project(Tokenizers)
2121

2222
option(TOKENIZERS_BUILD_TEST "Build tests" OFF)
2323

24+
# Ignore weak attribute warning
25+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes")
26+
2427
set(ABSL_ENABLE_INSTALL ON)
2528
set(ABSL_PROPAGATE_CXX_STD ON)
2629
set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE})

include/base64.h

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,8 @@ inline Error validate(uint32_t v) {
6969
}
7070

7171
inline Error decode(const std::string_view &input, std::string &output) {
72-
if (input.size() != 4) {
73-
fprintf(stderr, "input length must be 4, got %zu", input.size());
74-
return Error::Base64DecodeFailure;
75-
}
72+
TK_CHECK_OR_RETURN_ERROR(input.size() == 4, Base64DecodeFailure,
73+
"input length must be 4, got %zu", input.size());
7674

7775
uint32_t val = 0;
7876

@@ -104,10 +102,8 @@ inline Error decode(const std::string_view &input, std::string &output) {
104102

105103
inline Error decode_1_padding(const std::string_view &input,
106104
std::string &output) {
107-
if (input.size() != 3) {
108-
fprintf(stderr, "input length must be 3, got %zu", input.size());
109-
return Error::Base64DecodeFailure;
110-
}
105+
TK_CHECK_OR_RETURN_ERROR(input.size() == 3, Base64DecodeFailure,
106+
"input length must be 3, got %zu", input.size());
111107

112108
uint32_t val = 0;
113109

@@ -133,7 +129,8 @@ inline Error decode_1_padding(const std::string_view &input,
133129

134130
inline Error decode_2_padding(const std::string_view &input,
135131
std::string &output) {
136-
TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure);
132+
TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure,
133+
"input length must be 2, got %zu", input.size());
137134

138135
uint32_t val = 0;
139136

@@ -154,18 +151,13 @@ inline Error decode_2_padding(const std::string_view &input,
154151
} // namespace detail
155152

156153
inline tokenizers::Result<std::string> decode(const std::string_view &input) {
157-
if (input.empty()) {
158-
fprintf(stderr, "empty input");
159-
return Error::Base64DecodeFailure;
160-
}
154+
TK_CHECK_OR_RETURN_ERROR(!input.empty(), Base64DecodeFailure, "empty input");
161155

162156
// Faster than `input.size() % 4`.
163-
if ((input.size() & 3) != 0 || input.size() < 4) {
164-
fprintf(stderr,
165-
"input length must be larger than 4 and is multiple of 4, got %zu",
166-
input.size());
167-
return Error::Base64DecodeFailure;
168-
}
157+
TK_CHECK_OR_RETURN_ERROR(
158+
(input.size() & 3) == 0 && input.size() >= 4, Base64DecodeFailure,
159+
"input length must be larger than 4 and is multiple of 4, got %zu",
160+
input.size());
169161

170162
std::string output;
171163
output.reserve(input.size() / 4 * 3);

include/error.h

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#pragma once
1515

16+
#include "log.h"
1617
#include <stdint.h>
1718

1819
namespace tokenizers {
@@ -59,11 +60,14 @@ enum class Error : error_code_t {
5960
* TODO: Add logging support
6061
* @param[in] cond__ The condition to be checked, asserted as true.
6162
* @param[in] error__ Error enum value to return without the `Error::` prefix,
62-
* like `InvalidArgument`.
63+
* like `Base64DecodeFailure`.
64+
* @param[in] message__ Format string for the log error message.
65+
* @param[in] ... Optional additional arguments for the format string.
6366
*/
64-
#define TK_CHECK_OR_RETURN_ERROR(cond__, error__) \
67+
#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
6568
{ \
6669
if (!(cond__)) { \
70+
TK_LOG(Error, message__, ##__VA_ARGS__); \
6771
return ::tokenizers::Error::error__; \
6872
} \
6973
}
@@ -72,11 +76,80 @@ enum class Error : error_code_t {
7276
* If error__ is not Error::Ok, return the specified Error
7377
* TODO: Add logging support
7478
* @param[in] error__ Error enum value to return without the `Error::` prefix,
75-
* like `InvalidArgument`.
79+
* like `Base64DecodeFailure`.
80+
* @param[in] ... Optional format string for the log error message and its
81+
* arguments.
7682
*/
77-
#define TK_CHECK_OK_OR_RETURN_ERROR(error__) \
78-
{ \
79-
if (error__ != ::tokenizers::Error::Ok) { \
80-
return error__; \
83+
#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \
84+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__)
85+
86+
// Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
87+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \
88+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, \
89+
4, 3, 2, 1) \
90+
(__VA_ARGS__)
91+
92+
/**
93+
* Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
94+
* This macro selects the correct version of
95+
* TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR based on the number of arguments passed.
96+
* It uses a trick with the preprocessor to count the number of arguments and
97+
* then selects the appropriate macro.
98+
*
99+
* The macro expansion uses __VA_ARGS__ to accept any number of arguments and
100+
* then appends them to TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_, followed by the
101+
* count of arguments. The count is determined by the macro
102+
* TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT which takes the arguments and
103+
* passes them along with a descending sequence of numbers (10 down to 1). The preprocessor then
104+
* matches this sequence to the correct number of arguments provided.
105+
*
106+
* If two arguments are passed, TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 is
107+
* selected, suitable for cases where an error code and a custom message are
108+
* provided. If only one argument is passed,
109+
* TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 is selected, which is used for cases
110+
* with just an error code.
111+
*
112+
* Usage:
113+
* TK_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1
114+
* TK_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2
115+
*/
116+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(_1, _2, _3, _4, _5, _6, \
117+
_7, _8, _9, _10, N, ...) \
118+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N
119+
120+
// Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
121+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \
122+
do { \
123+
const auto et_error__ = (error__); \
124+
if (et_error__ != ::tokenizers::Error::Ok) { \
125+
return et_error__; \
81126
} \
82-
}
127+
} while (0)
128+
129+
// Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
130+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \
131+
do { \
132+
const auto et_error__ = (error__); \
133+
if (et_error__ != ::tokenizers::Error::Ok) { \
134+
TK_LOG(Error, message__, ##__VA_ARGS__); \
135+
return et_error__; \
136+
} \
137+
} while (0)
138+
139+
// Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
140+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \
141+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
142+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \
143+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
144+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \
145+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
146+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \
147+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
148+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \
149+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
150+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \
151+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
152+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \
153+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
154+
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \
155+
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2

0 commit comments

Comments
 (0)