From 98e28aa5242ccc65f0b2e716459ccb579afac65c Mon Sep 17 00:00:00 2001 From: iosetek Date: Mon, 15 Mar 2021 15:14:05 +0100 Subject: [PATCH] Test for proving Loading rule is not thread safe When multiple threads are trying to load rule at the same time even when they refer to different rule set a string representing rule is malformed. It's believed that this is caused by seclang-parser which contains plenty of global variables. To run the test simply enter 'test' directory and run "./unit_tests" --- test/unit/unit.cc | 143 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 140 insertions(+), 3 deletions(-) diff --git a/test/unit/unit.cc b/test/unit/unit.cc index 49aeabd204..4a9a8ea29d 100644 --- a/test/unit/unit.cc +++ b/test/unit/unit.cc @@ -13,12 +13,14 @@ * */ -#include -#include -#include +#include +#include #include +#include #include +#include +#include #include "modsecurity/rules_set.h" #include "modsecurity/modsecurity.h" @@ -115,6 +117,136 @@ void perform_unit_test(ModSecurityTest *test, UnitTest *t, } } +struct thread_data { + bool *done; + bool *success; + std::string *err; +}; + +void spawn_rule_loading_thread(thread_data d) { + modsecurity::RulesSet *r = new modsecurity::RulesSet(); + auto rules = R"(SecAction "id:900000,phase:1,pass,nolog,setvar:tx.paranoia_level=1")"; + *(d.success) = r->load(rules, "") == 1; + if (!*d.success) { + *d.err = r->getParserError(); + } + delete r; + *(d.done) = true; +} + +std::vector spawn_threads( + unsigned number_of_threads, + std::chrono::milliseconds delay) { + + auto threads = std::vector(); + + for (int i = 0; i < number_of_threads; i++) { + threads.push_back({new bool(false), new bool(false), new std::string("")}); + std::thread (spawn_rule_loading_thread, threads.back()).detach(); + + if (delay > std::chrono::milliseconds(0)) { + std::this_thread::sleep_for(delay); + } + } + + return threads; +} + +bool threads_finished(std::vector threads) { + for (auto t: threads) { + if (!*(t.done)) { + return false; + } + } + return true; +} + +bool threads_succeeded(std::vector threads) { + for (auto t: threads) { + if (!*(t.success)) { + return false; + } + } + return true; +} + +void assert_threads(std::string test_name, std::vector threads) { + std::string result = threads_succeeded(threads) ? "succeeded" : "failed"; + std::cout << "Test: '" << test_name << "' " << result << ".\n"; + if (!threads_succeeded(threads)) { + for (int i = 0; i < threads.size(); i++) { + std::cout << "thread [" << i << "] returned: '" << *(threads[i].err) << "'\n"; + } + } +} + +void clean_up(std::vector threads) { + for (auto t: threads) { + delete t.done; + delete t.success; + delete t.err; + } +} + +bool wait_for_threads( + std::string test_name, + std::vector threads, + std::chrono::seconds timeout) { + + auto start = std::chrono::system_clock::now(); + while (true) { + if (threads_finished(threads)) { + break; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (std::chrono::system_clock::now() - start > timeout) { + std::cout << "Test: '" << test_name << "' reached timeout. Failed.\n"; + return false; + } + } + return true; +} + +void test_1_overlaping_thread() { + std::string test_name = "Test 1 thread"; + + auto threads = spawn_threads(1, std::chrono::milliseconds(0)); + + auto finished = wait_for_threads(test_name, threads, std::chrono::seconds(1)); + + if (finished) { + assert_threads(test_name, threads); + } + clean_up(threads); + +} + +void test_3_overlaping_threads() { + std::string test_name = "Test 3 overlapping load rule threads"; + + auto threads = spawn_threads(3, std::chrono::milliseconds(0)); + + auto finished = wait_for_threads(test_name, threads, std::chrono::seconds(1)); + + if (finished) { + assert_threads(test_name, threads); + } + clean_up(threads); +} + +void test_3_non_overlaping_threads() { + std::string test_name = "Test 3 non overlapping load rule threads (delay between)"; + + auto threads = spawn_threads(3, std::chrono::milliseconds(100)); + + auto finished = wait_for_threads(test_name, threads, std::chrono::seconds(1)); + + if (finished) { + assert_threads(test_name, threads); + } + clean_up(threads); +} int main(int argc, char **argv) { int total = 0; @@ -224,6 +356,11 @@ int main(int argc, char **argv) { } delete vec; } + + std::cout << "\n\nExecuting thread tests.\n"; + test_1_overlaping_thread(); + test_3_overlaping_threads(); + test_3_non_overlaping_threads(); }