Slow performance with groupby using a custom DataArray grouper #8377

@alessioarena

Description

What is your issue?

I have code that computes a per-pixel nearest neighbor match between two datasets and then uses it as the grouper for a groupby + aggregation.
The calculation is generally lazy, using dask.

I recently noticed that groupby is very slow in this setup, with the supposedly lazy computation taking in excess of 10 minutes for an index of approximately 4000 by 4000.
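
For context, the pattern is roughly the following (a simplified, self-contained sketch, not my actual code; the names, sizes and chunking are made up):

import numpy as np
import xarray as xr

# dask-backed data variable (stand-in for my real dataset)
data = xr.DataArray(
    np.random.rand(4000, 4000), dims=("y", "x")
).chunk({"y": 1000, "x": 1000})

# per-pixel nearest neighbor match result, used as the grouper
match = xr.DataArray(
    np.random.randint(0, 500, size=(4000, 4000)), dims=("y", "x"), name="match_id"
)

# groupby with a custom DataArray grouper + aggregation; this is the step that gets slow
result = data.groupby(match).mean()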

I did a bit of digging with line_profiler and narrowed the slowness down to this:

Timer unit: 1e-09 s

Total time: 0.263679 s
File: /env/lib/python3.10/site-packages/xarray/core/duck_array_ops.py
Function: array_equiv at line 260

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   260                                           def array_equiv(arr1, arr2):
   261                                               """Like np.array_equal, but also allows values to be NaN in both arrays"""
   262     22140   96490101.0   4358.2     36.6      arr1 = asarray(arr1)
   263     22140   34155953.0   1542.7     13.0      arr2 = asarray(arr2)
   264     22140  119855572.0   5413.5     45.5      lazy_equiv = lazy_array_equiv(arr1, arr2)
   265     22140    7390478.0    333.8      2.8      if lazy_equiv is None:
   266                                                   with warnings.catch_warnings():
   267                                                       warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
   268                                                       flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
   269                                                       return bool(flag_array.all())
   270                                               else:
   271     22140    5787053.0    261.4      2.2          return lazy_equiv

Total time: 242.247 s
File: /env/lib/python3.10/site-packages/xarray/core/indexing.py
Function: __getitem__ at line 1419

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1419                                               def __getitem__(self, key):
  1420     22140   26764337.0   1208.9      0.0          if not isinstance(key, VectorizedIndexer):
  1421                                                       # if possible, short-circuit when keys are effectively slice(None)
  1422                                                       # This preserves dask name and passes lazy array equivalence checks
  1423                                                       # (see duck_array_ops.lazy_array_equiv)
  1424     22140   10513930.0    474.9      0.0              rewritten_indexer = False
  1425     22140    4602305.0    207.9      0.0              new_indexer = []
  1426     66420   61804870.0    930.5      0.0              for idim, k in enumerate(key.tuple):
  1427     88560   78516641.0    886.6      0.0                  if isinstance(k, Iterable) and (
  1428     22140  151748667.0   6854.1      0.1                      not is_duck_dask_array(k)
  1429     22140        2e+11    1e+07     93.6                      and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim]))
  1430                                                           ):
  1431                                                               new_indexer.append(slice(None))
  1432                                                               rewritten_indexer = True
  1433                                                           else:
  1434     44280   40322984.0    910.6      0.0                      new_indexer.append(k)
  1435     22140    4847251.0    218.9      0.0              if rewritten_indexer:
  1436                                                           key = type(key)(tuple(new_indexer))
  1437                                           
  1438     22140   24251221.0   1095.4      0.0          if isinstance(key, BasicIndexer):
  1439                                                       return self.array[key.tuple]
  1440     22140    9613954.0    434.2      0.0          elif isinstance(key, VectorizedIndexer):
  1441                                                       return self.array.vindex[key.tuple]
  1442                                                   else:
  1443     22140    8618414.0    389.3      0.0              assert isinstance(key, OuterIndexer)
  1444     22140   26601491.0   1201.5      0.0              key = key.tuple
  1445     22140    6010672.0    271.5      0.0              try:
  1446     22140        2e+10 678487.7      6.2                  return self.array[key]
  1447                                                       except NotImplementedError:
  1448                                                           # manual orthogonal indexing.
  1449                                                           # TODO: port this upstream into dask in a saner way.
  1450                                                           value = self.array
  1451                                                           for axis, subkey in reversed(list(enumerate(key))):
  1452                                                               value = value[(slice(None),) * axis + (subkey,)]
  1453                                                           return value

The test duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim])) is repeated many times (22140 hits in the profile above). Even though each individual call is reasonably fast, together they account for over 90% of the time spent in __getitem__, and much of that could be avoided by adding a cheap length check first, e.g.:

                if isinstance(k, Iterable) and (
                    not is_duck_dask_array(k)
                    and len(k) == self.array.shape[idim]
                    and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim]))
                ):

This would work better because, even though array_equiv already performs an equivalent check, the array to test against is always created with np.arange first, and that allocation is ultimately the bottleneck; in the cProfile output below, np.arange alone accounts for about 225 of the 299 seconds:

         74992059 function calls (73375414 primitive calls) in 298.934 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    22140  225.296    0.010  225.296    0.010 {built-in method numpy.arange}
   177123    3.192    0.000    3.670    0.000 inspect.py:2920(__init__)
110702/110701    2.180    0.000    2.180    0.000 {built-in method numpy.asarray}
11690863/11668723    2.036    0.000    5.043    0.000 {built-in method builtins.isinstance}
   287827    1.876    0.000    3.768    0.000 utils.py:25(meta_from_array)
   132843    1.872    0.000    7.649    0.000 inspect.py:2280(_signature_from_function)
   974166    1.485    0.000    2.558    0.000 inspect.py:2637(__init__)
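
To get a feel for the saving in isolation, here is a toy comparison (it uses np.array_equal rather than the actual xarray code path, and the sizes are only illustrative):

import timeit
import numpy as np

n = 4000 * 4000    # roughly the index size in my case
k = np.arange(10)  # a short indexer that obviously cannot equal arange(n)

# current behaviour: np.arange(n) is materialized before every comparison
t_current = timeit.timeit(lambda: np.array_equal(k, np.arange(n)), number=10)

# proposed behaviour: a cheap length check short-circuits the expensive comparison
t_proposed = timeit.timeit(lambda: len(k) == n and np.array_equal(k, np.arange(n)), number=10)

print(f"arange every call:  {t_current:.4f} s")
print(f"length check first: {t_proposed:.6f} s")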
