diff --git a/libcxx/include/__algorithm/ranges_starts_with.h b/libcxx/include/__algorithm/ranges_starts_with.h index ae145d59010ae..7243d3921aa62 100644 --- a/libcxx/include/__algorithm/ranges_starts_with.h +++ b/libcxx/include/__algorithm/ranges_starts_with.h @@ -9,15 +9,19 @@ #ifndef _LIBCPP___ALGORITHM_RANGES_STARTS_WITH_H #define _LIBCPP___ALGORITHM_RANGES_STARTS_WITH_H -#include <__algorithm/in_in_result.h> +#include <__algorithm/ranges_equal.h> #include <__algorithm/ranges_mismatch.h> #include <__config> #include <__functional/identity.h> #include <__functional/ranges_operations.h> +#include <__functional/reference_wrapper.h> #include <__iterator/concepts.h> +#include <__iterator/distance.h> #include <__iterator/indirectly_comparable.h> +#include <__iterator/next.h> #include <__ranges/access.h> #include <__ranges/concepts.h> +#include <__ranges/size.h> #include <__utility/move.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -49,14 +53,37 @@ struct __starts_with { _Pred __pred = {}, _Proj1 __proj1 = {}, _Proj2 __proj2 = {}) { - return __mismatch::__go( + if constexpr (sized_sentinel_for<_Sent1, _Iter1> && sized_sentinel_for<_Sent2, _Iter2>) { + auto __n1 = ranges::distance(__first1, __last1); + auto __n2 = ranges::distance(__first2, __last2); + if (__n2 == 0) { + return true; + } + if (__n2 > __n1) { + return false; + } + + if constexpr (contiguous_iterator<_Iter1> && contiguous_iterator<_Iter2>) { + auto __end1 = ranges::next(__first1, __n2); + return ranges::equal( + std::move(__first1), + std::move(__end1), + std::move(__first2), + std::move(__last2), + std::ref(__pred), + std::ref(__proj1), + std::ref(__proj2)); + } + } + + return ranges::mismatch( std::move(__first1), std::move(__last1), std::move(__first2), - std::move(__last2), - __pred, - __proj1, - __proj2) + __last2, + std::ref(__pred), + std::ref(__proj1), + std::ref(__proj2)) .in2 == __last2; } @@ -68,17 +95,40 @@ struct __starts_with { requires indirectly_comparable, iterator_t<_Range2>, _Pred, _Proj1, _Proj2> [[nodiscard]] _LIBCPP_HIDE_FROM_ABI static constexpr bool operator()(_Range1&& __range1, _Range2&& __range2, _Pred __pred = {}, _Proj1 __proj1 = {}, _Proj2 __proj2 = {}) { - return __mismatch::__go( + if constexpr (sized_range<_Range1> && sized_range<_Range2>) { + auto __n1 = ranges::size(__range1); + auto __n2 = ranges::size(__range2); + if (__n2 == 0) { + return true; + } + if (__n2 > __n1) { + return false; + } + + if constexpr (contiguous_range<_Range1> && contiguous_range<_Range2>) { + return ranges::equal( + ranges::begin(__range1), + ranges::next(ranges::begin(__range1), __n2), + ranges::begin(__range2), + ranges::end(__range2), + std::ref(__pred), + std::ref(__proj1), + std::ref(__proj2)); + } + } + + return ranges::mismatch( ranges::begin(__range1), ranges::end(__range1), ranges::begin(__range2), ranges::end(__range2), - __pred, - __proj1, - __proj2) + std::ref(__pred), + std::ref(__proj1), + std::ref(__proj2)) .in2 == ranges::end(__range2); } }; + inline namespace __cpo { inline constexpr auto starts_with = __starts_with{}; } // namespace __cpo diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.starts_with/ranges.starts_with.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.starts_with/ranges.starts_with.pass.cpp index 0f2284edde81c..cfc2e37d596f8 100644 --- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.starts_with/ranges.starts_with.pass.cpp +++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.starts_with/ranges.starts_with.pass.cpp @@ -216,12 +216,17 @@ constexpr void test_iterators() { constexpr bool test() { types::for_each(types::cpp20_input_iterator_list{}, []() { types::for_each(types::cpp20_input_iterator_list{}, []() { - if constexpr (std::forward_iterator && std::forward_iterator) + if constexpr (std::forward_iterator && std::forward_iterator) { test_iterators(); - if constexpr (std::forward_iterator) + } + if constexpr (std::forward_iterator) { + test_iterators, Iter2, Iter2>(); test_iterators, Iter2, Iter2>(); - if constexpr (std::forward_iterator) + } + if constexpr (std::forward_iterator) { + test_iterators>(); test_iterators>(); + } test_iterators, Iter2, sized_sentinel>(); }); });