Implemented NumbaExecutionEngine #61487
Changes from 10 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -178,6 +178,60 @@ def apply( | |
""" | ||
|
||
|
||
class NumbaExecutionEngine(BaseExecutionEngine): | ||
""" | ||
Numba-based execution engine for pandas apply and map operations. | ||
""" | ||
|
||
@staticmethod | ||
def map( | ||
data: np.ndarray | Series | DataFrame, | ||
func, | ||
args: tuple, | ||
kwargs: dict, | ||
decorator: Callable | None, | ||
skip_na: bool, | ||
): | ||
""" | ||
Elementwise map for the Numba engine. Currently not supported. | ||
""" | ||
raise NotImplementedError("Numba map is not implemented yet.") | ||
|
||
@staticmethod | ||
def apply( | ||
data: np.ndarray | Series | DataFrame, | ||
func, | ||
args: tuple, | ||
kwargs: dict, | ||
decorator: Callable, | ||
axis: int | str, | ||
): | ||
""" | ||
Apply `func` along the given axis using Numba. | ||
""" | ||
engine_kwargs: dict[str, bool] | None = ( | ||
decorator if isinstance(decorator, dict) else None | ||
) | ||
|
||
looper_args, looper_kwargs = prepare_function_arguments( | ||
func, | ||
args, | ||
kwargs, | ||
num_required_args=1, | ||
) | ||
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has | ||
# incompatible type "Callable[..., Any] | str | list[Callable | ||
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | | ||
# list[Callable[..., Any] | str]]"; expected "Hashable" | ||
Review comment: I don't think I understand this comment.
Reply: This comment is from the original logic. I didn't write it myself, but I believe it refers to a type checker (likely mypy or pyright) complaining about the use of […]. Let me know if you'd like me to remove the comment, and whether other comments should be checked as well. |
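As a rough illustration of what the quoted type-checker message is about: a function wrapped in `functools.lru_cache` uses its arguments as cache keys, so they must be hashable. The sketch below assumes the looper factory is cached in this way, as the `_lru_cache_wrapper` in the error text suggests. `build_looper` and `double` are hypothetical names, not pandas functions.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def build_looper(func):
    # Hypothetical stand-in for a cached looper factory.
    def looper(values):
        return [func(v) for v in values]

    return looper


def double(x):
    return 2 * x


# A plain function is hashable, so the second call is a cache hit.
looper_a = build_looper(double)
looper_b = build_looper(double)
same_object = looper_a is looper_b

# A dict is one member of the union in the error message and is not
# hashable, so the cached factory rejects it at runtime.
try:
    build_looper({"a": double})
    dict_rejected = False
except TypeError:
    dict_rejected = True
```

The static error is the type checker flagging the same restriction ahead of time, because the declared type of `func` includes dicts and lists.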
||
nb_looper = generate_apply_looper( | ||
Review comment: I personally wouldn't abbreviate to nb; it's not super clear, imho. Just calling this […] |
||
func, | ||
**get_jit_arguments(engine_kwargs), | ||
Review comment: I think we can make this simpler if you change […]. What you are doing now is to extract the […]
Reply: I also noticed that […] |
||
) | ||
result = nb_looper(data, axis, *looper_args) | ||
# If we made the result 2-D, squeeze it back to 1-D | ||
return np.squeeze(result) | ||
|
||
|
||
def frame_apply( | ||
obj: DataFrame, | ||
func: AggFuncType, | ||
|
@@ -1094,23 +1148,19 @@ def wrapper(*args, **kwargs): | |
return wrapper | ||
|
||
if engine == "numba": | ||
args, kwargs = prepare_function_arguments( | ||
self.func, # type: ignore[arg-type] | ||
numba = import_optional_dependency("numba") | ||
|
||
if not hasattr(numba.jit, "__pandas_udf__"): | ||
numba.jit.__pandas_udf__ = NumbaExecutionEngine | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I think it'd be a simpler approach is to implement this logic here: https://github.com/pandas-dev/pandas/blob/main/pandas/core/frame.py#L10563 There, now we are considering two cases:
I would simplify that and just support engines with the engine interface
Since we want to support def apply(...):
if engine == "numba":
numba = import_optional_dependency("numba")
numba_jit = numba.jit(**engine_kwargs)
numba_jit.__pandas_udf__ = NumbaExecutionEngine From this point, all the code can pretend engine is going to be The challenge is that numba and the default engine share some code, and with this approach they'll be running independently. The default engine won't know anything about an When we move the default engine to a Does this approach makes sense to you? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes this makes sense thank you |
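The dispatch idea discussed in this thread can be sketched without Numba installed. This is a minimal illustration, not the pandas implementation: `SketchEngine`, `fake_jit`, and `dispatch_apply` are hypothetical names. Only the `__pandas_udf__` attribute and the `apply(data, func, args, kwargs, decorator, axis)` signature come from the diff above.

```python
from typing import Callable


class SketchEngine:
    """Hypothetical stand-in for an execution engine class."""

    @staticmethod
    def apply(data, func, args, kwargs, decorator, axis):
        # A real engine would JIT-compile here; this sketch only wraps
        # the function with the decorator and loops over the rows.
        wrapped = decorator(func) if decorator is not None else func
        return [wrapped(row, *args, **kwargs) for row in data]


def fake_jit(func: Callable) -> Callable:
    """Hypothetical decorator playing the role of numba.jit."""
    return func


# The decorator advertises its engine through this attribute.
fake_jit.__pandas_udf__ = SketchEngine


def dispatch_apply(data, func, engine=None, args=(), kwargs=None, axis=0):
    """Route to the engine when the decorator carries __pandas_udf__."""
    kwargs = kwargs or {}
    if engine is not None and hasattr(engine, "__pandas_udf__"):
        return engine.__pandas_udf__.apply(
            data, func, args, kwargs, engine, axis
        )
    # Default path: plain Python loop, unaware of any engine.
    return [func(row, *args, **kwargs) for row in data]


rows = [[1, 2], [3, 4]]
engine_result = dispatch_apply(rows, sum, engine=fake_jit)
default_result = dispatch_apply(rows, sum)
```

In this sketch the caller never checks for a specific engine name. Any decorator that carries `__pandas_udf__` is treated the same way, which seems to be the simplification the reviewer is asking for.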
||
|
||
result = numba.jit.__pandas_udf__.apply( | ||
self.values, | ||
self.func, | ||
self.args, | ||
self.kwargs, | ||
num_required_args=1, | ||
) | ||
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has | ||
# incompatible type "Callable[..., Any] | str | list[Callable | ||
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | | ||
# list[Callable[..., Any] | str]]"; expected "Hashable" | ||
nb_looper = generate_apply_looper( | ||
self.func, # type: ignore[arg-type] | ||
**get_jit_arguments(engine_kwargs), | ||
engine_kwargs, | ||
self.axis, | ||
) | ||
result = nb_looper(self.values, self.axis, *args) | ||
# If we made the result 2-D, squeeze it back to 1-D | ||
result = np.squeeze(result) | ||
else: | ||
result = np.apply_along_axis( | ||
wrap_function(self.func), | ||
|
Review comment: This is the error when users write something like `df.map(func, engine=numba.jit)`. I think it'll be easier for users to understand if the message is something like "The Numba engine is not implemented for the map method yet".
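A minimal sketch of the suggested wording, assuming the `map` signature shown in the diff. `SketchNumbaEngine` is a hypothetical stand-in, not the class from the PR.

```python
class SketchNumbaEngine:
    """Hypothetical stand-in mirroring the engine's map signature."""

    @staticmethod
    def map(data, func, args, kwargs, decorator, skip_na):
        # The message names both the engine and the method, so a user who
        # called df.map(..., engine=...) can tell what is unsupported.
        raise NotImplementedError(
            "The Numba engine is not implemented for the map method yet."
        )


try:
    SketchNumbaEngine.map([1, 2], abs, (), {}, None, False)
except NotImplementedError as exc:
    message = str(exc)
```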