diff --git a/libcxx/benchmarks/algorithms/mismatch.bench.cpp b/libcxx/benchmarks/algorithms/mismatch.bench.cpp index 9274932a764c5..06289068bb049 100644 --- a/libcxx/benchmarks/algorithms/mismatch.bench.cpp +++ b/libcxx/benchmarks/algorithms/mismatch.bench.cpp @@ -10,6 +10,15 @@ #include #include +void BenchmarkSizes(benchmark::internal::Benchmark* Benchmark) { + Benchmark->DenseRange(1, 8); + for (size_t i = 16; i != 1 << 20; i *= 2) { + Benchmark->Arg(i - 1); + Benchmark->Arg(i); + Benchmark->Arg(i + 1); + } +} + // TODO: Look into benchmarking aligned and unaligned memory explicitly // (currently things happen to be aligned because they are malloced that way) template @@ -24,8 +33,8 @@ static void bm_mismatch(benchmark::State& state) { benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin())); } } -BENCHMARK(bm_mismatch)->DenseRange(1, 8)->Range(16, 1 << 20); -BENCHMARK(bm_mismatch)->DenseRange(1, 8)->Range(16, 1 << 20); -BENCHMARK(bm_mismatch)->DenseRange(1, 8)->Range(16, 1 << 20); +BENCHMARK(bm_mismatch)->Apply(BenchmarkSizes); +BENCHMARK(bm_mismatch)->Apply(BenchmarkSizes); +BENCHMARK(bm_mismatch)->Apply(BenchmarkSizes); BENCHMARK_MAIN(); diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h index 4eb693a1f2e9d..d933a84cada9e 100644 --- a/libcxx/include/__algorithm/mismatch.h +++ b/libcxx/include/__algorithm/mismatch.h @@ -64,7 +64,10 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __ constexpr size_t __unroll_count = 4; constexpr size_t __vec_size = __native_vector_size<_Tp>; using __vec = __simd_vector<_Tp, __vec_size>; + if (!__libcpp_is_constant_evaluated()) { + auto __orig_first1 = __first1; + auto __last2 = __first2 + (__last1 - __first1); while (static_cast(__last1 - __first1) >= __unroll_count * __vec_size) [[__unlikely__]] { __vec __lhs[__unroll_count]; __vec __rhs[__unroll_count]; @@ -84,8 +87,32 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __ __first1 += __unroll_count * __vec_size; __first2 += __unroll_count * __vec_size; } + + // check the remaining 0-3 vectors + while (static_cast(__last1 - __first1) >= __vec_size) { + if (auto __cmp_res = std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2); + !std::__all_of(__cmp_res)) { + auto __offset = std::__find_first_not_set(__cmp_res); + return {__first1 + __offset, __first2 + __offset}; + } + __first1 += __vec_size; + __first2 += __vec_size; + } + + if (__last1 - __first1 == 0) + return {__first1, __first2}; + + // Check if we can load elements in front of the current pointer. If that's the case load a vector at + // (last - vector_size) to check the remaining elements + if (static_cast(__first1 - __orig_first1) >= __vec_size) { + __first1 = __last1 - __vec_size; + __first2 = __last2 - __vec_size; + auto __offset = + std::__find_first_not_set(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2)); + return {__first1 + __offset, __first2 + __offset}; + } // else loop over the elements individually } - // TODO: Consider vectorizing the tail + return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2); } diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp index e7f3994d977dc..55c9eea863c3f 100644 --- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp +++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp @@ -184,5 +184,33 @@ int main(int, char**) { } } + { // check the tail of the vectorized loop + for (size_t vec_size = 1; vec_size != 256; ++vec_size) { + { + std::vector lhs(256); + std::vector rhs(256); + + check(lhs, rhs, lhs.size()); + lhs.back() = 1; + check(lhs, rhs, lhs.size() - 1); + lhs.back() = 0; + rhs.back() = 1; + check(lhs, rhs, lhs.size() - 1); + rhs.back() = 0; + } + { + std::vector lhs(256); + std::vector rhs(256); + + check(lhs, rhs, lhs.size()); + lhs.back() = 1; + check(lhs, rhs, lhs.size() - 1); + lhs.back() = 0; + rhs.back() = 1; + check(lhs, rhs, lhs.size() - 1); + rhs.back() = 0; + } + } + } return 0; }