Skip to content

PERF: lexsort_depth #47511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions asv_bench/benchmarks/multiindex_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def setup(self):
self.mi_small = MultiIndex.from_product(
[np.arange(100), list("A"), list("A")], names=["one", "two", "three"]
)
self.mi_wide = MultiIndex.from_tuples(
[
np.hstack([2, np.ones(99)]),
np.ones(100),
np.ones(100),
],
)

def time_large_get_loc(self):
    # Single cold lookup of a fixed tuple key in the large MultiIndex
    # (presumably includes one-time engine construction — confirm asv setup semantics).
    self.mi_large.get_loc((999, 19, "Z"))
Expand All @@ -32,6 +39,9 @@ def time_large_get_loc_warm(self):
for _ in range(1000):
self.mi_large.get_loc((999, 19, "Z"))

def time_wide_get_loc(self):
    # Lookup in the "wide" index added in setup(): 100 entries that are
    # identical except for the first element, i.e. lexsort depth 0.
    # Exercises the early-exit path of the new lexsort_depth algorithm.
    self.mi_wide.get_loc((1, 1))

def time_med_get_loc(self):
    # Single lookup of a fixed tuple key in the medium-sized MultiIndex.
    self.mi_med.get_loc((999, 9, "A"))

Expand Down
1 change: 1 addition & 0 deletions pandas/_libs/algos.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def unique_deltas(
arr: np.ndarray, # const int64_t[:]
) -> np.ndarray: ... # np.ndarray[np.int64, ndim=1]
# True if the arrays, compared elementwise in order, form a lexicographic sort.
def is_lexsorted(list_of_arrays: list[npt.NDArray[np.int64]]) -> bool: ...
# Number of leading arrays (levels) that are lexsorted; 0..len(list_of_arrays).
def lexsort_depth(list_of_arrays: list[npt.NDArray[np.int64]]) -> int: ...
def groupsort_indexer(
index: np.ndarray, # const int64_t[:]
ngroups: int,
Expand Down
48 changes: 46 additions & 2 deletions pandas/_libs/algos.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,11 @@ def is_lexsorted(list_of_arrays: list) -> bint:
assert arr.dtype.name == 'int64'
vecs[i] = <int64_t*>cnp.PyArray_DATA(arr)

# Assume uniqueness??
with nogil:
for i in range(1, n):
for k in range(nlevels):
cur = vecs[k][i]
pre = vecs[k][i -1]
pre = vecs[k][i-1]
if cur == pre:
continue
elif cur > pre:
Expand All @@ -186,6 +185,51 @@ def is_lexsorted(list_of_arrays: list) -> bint:
return result


@cython.wraparound(False)
@cython.boundscheck(False)
def lexsort_depth(list_of_arrays: list) -> Py_ssize_t:
    """
    Same as `is_lexsorted`, but keeps track of lexsort depth
    as we iterate through the elements, and exits early if depth
    is zero.

    Parameters
    ----------
    list_of_arrays : list
        int64 ndarrays of equal length, one per level, outermost first.

    Returns
    -------
    Py_ssize_t
        The number of leading arrays (levels) that are lexsorted,
        between 0 and ``len(list_of_arrays)`` inclusive.
    """
    cdef:
        Py_ssize_t i, depth, k
        Py_ssize_t n, nlevels
        int64_t cur, pre
        ndarray arr

    nlevels = len(list_of_arrays)
    n = len(list_of_arrays[0])

    # Collect raw data pointers so the comparison loop can run without the GIL.
    # NOTE(review): malloc result is not NULL-checked, and a failing assert
    # below would leak `vecs` — matches the pre-existing pattern in
    # `is_lexsorted`, but worth confirming upstream.
    cdef int64_t **vecs = <int64_t**>malloc(nlevels * sizeof(int64_t*))
    for i in range(nlevels):
        arr = list_of_arrays[i]
        assert arr.dtype.name == 'int64'
        vecs[i] = <int64_t*>cnp.PyArray_DATA(arr)

    with nogil:
        depth = nlevels
        for i in range(1, n):
            # Compare row i against row i-1, level by level, but only down
            # to the current known depth — deeper levels are already ruled out.
            k = 0
            while k < depth:
                cur = vecs[k][i]
                pre = vecs[k][i-1]
                if cur == pre:
                    # Tie at level k: ordering of this pair is decided deeper.
                    k += 1
                    continue
                elif cur > pre:
                    # Strictly increasing at level k: this pair is ordered;
                    # deeper levels are irrelevant for this pair.
                    break
                else:
                    # Decrease at level k: levels k and beyond are not
                    # lexsorted; shrink the running depth.
                    depth = min(k, depth)
                    k += 1
            if depth == 0:
                # Depth can't increase, so if we've reached 0, break outer loop.
                break
    free(vecs)
    return depth


@cython.boundscheck(False)
@cython.wraparound(False)
def groupsort_indexer(const intp_t[:] index, Py_ssize_t ngroups):
Expand Down
5 changes: 1 addition & 4 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3851,10 +3851,7 @@ def drop_duplicates(self, keep: str | bool = "first") -> MultiIndex:
def _lexsort_depth(codes: list[np.ndarray], nlevels: int) -> int:
    """
    Count depth (up to a maximum of `nlevels`) with which codes are lexsorted.

    Parameters
    ----------
    codes : list of np.ndarray
        One code array per level.
    nlevels : int
        Maximum depth to check.

    Returns
    -------
    int
        The largest k such that the first k code arrays are lexsorted.
    """
    as_int64 = [ensure_int64(level_codes) for level_codes in codes]
    # Try the deepest prefix first; the first lexsorted prefix wins.
    for prefix_len in reversed(range(1, nlevels + 1)):
        if libalgos.is_lexsorted(as_int64[:prefix_len]):
            return prefix_len
    return 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the perf boost come from moving the loop into cython or from using a different algorithm?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

different algorithm - the main one first checks whether all levels are lexsorted, then checks whether the first
n_levels-1 are lexsorted, then checks whether the first n_levels-2 are lexsorted, etc. So it requires going
through the whole array many times. The algo in this branch just needs to go through the array once, keeping
track of the depth along the way, and exiting early if possible.

return libalgos.lexsort_depth(int64_codes)


def sparsify_labels(label_list, start: int = 0, sentinel=""):
Expand Down