Remove MaxAndArgmax Op #874
Conversation
Force-pushed from c8b52f1 to 402be24.
Looks like this is making good progress; much appreciated!
tests/link/test_numba.py
Outdated
def test_MaxAndArgmax(x, axes, exc):
    g = aem.MaxAndArgmax(axes)(x)

    if isinstance(g, list):
        g_fg = FunctionGraph(outputs=g)
    else:
        g_fg = FunctionGraph(outputs=[g])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            g_fg,
            [
                i.tag.test_value
                for i in g_fg.inputs
                if not isinstance(i, (SharedVariable, Constant))
            ],
        )
We can add a new `pytest.mark.parametrize` and a parameter that cycles through the two now-independent `Op`s and runs the same tests, e.g.:
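A rough sketch of that refactor, assuming it lives in tests/link/test_numba.py next to the existing `set_test_value` and `compare_numba_and_py` helpers (the single parametrized case is illustrative only):

import contextlib

import numpy as np
import pytest

import aesara.tensor as at
import aesara.tensor.math as aem
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph


@pytest.mark.parametrize("op", [aem.Max, aem.Argmax])
@pytest.mark.parametrize(
    "x, axes, exc",
    [
        # Illustrative case; the existing test's cases would be reused here
        (set_test_value(at.dmatrix(), np.random.random((3, 2))), [0], None),
    ],
)
def test_max_and_argmax(op, x, axes, exc):
    # The same test body now cycles over the two independent Ops
    g = op(axes)(x)
    g_fg = FunctionGraph(outputs=g if isinstance(g, list) else [g])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            g_fg,
            [
                i.tag.test_value
                for i in g_fg.inputs
                if not isinstance(i, (SharedVariable, Constant))
            ],
        )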
@@ -906,71 +895,6 @@ def validate_grad_graph(func):
    assert softmax_grad_legacy not in ops


def test_argmax_pushdown():
It should be possible to refactor these so that they work with each `Op` separately.
@@ -745,400 +733,9 @@ def test_isnan():
    f([[0, 1, 2]])


class TestMaxAndArgmax:
These tests need to be converted to work with the two `Op`s as well.
@@ -24,28 +22,6 @@
from tests.link.test_link import make_function


class TestMaxAndArgmax:
Same here.
@brandonwillard, thanks for the review! I will work on the requested changes.
Hi @brandonwillard! I have the test that fails because …
Force-pushed from d73f7e8 to 194bbbe.
I just made some small fixes and rebased onto `main`. The errors I'm seeing now are due to the unfinished Numba implementations. Is that older, more cryptic error still present? (Nevermind, I see it in `tests.tensor.test_math.TestMinMax.test_uint`; I'll look into it.)
Since these commits need to be restructured/squashed anyway, it would be good to reorganize them so that the `Max` and `ArgMax` updates and their corresponding tests are added first (e.g. in a single commit, or one for each `Op` along with their JAX and Numba implementations), then the removal/replacement of `MaxAndArgmax` in another commit.
@brandonwillard Thanks for the review! I'll be working on those changes.
Codecov Report
@@            Coverage Diff             @@
##             main     #874      +/-   ##
==========================================
- Coverage   79.25%   79.20%   -0.05%
==========================================
  Files         152      152
  Lines       47882    47882
  Branches    10909    10906       -3
==========================================
- Hits        37949    37926      -23
- Misses       7436     7454      +18
- Partials     2497     2502       +5
Hi @brandonwillard! I seem to be able to fix the errors that occurred after removing …
Force-pushed from 5a24cad to 91e3c4b.
Looks like an extra file got added: `aesara/tensor/\`.
# at_max,
# at_min,
Why have the `max` and `min` functions been removed from `TestLocalReduce`?
@brandonwillard I moved the `MaxAndArgmax` functionality to `Max`, and now `Max` is implemented as a `COp`, so the tests for `CAReduce` will not work with `Max`.
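For illustration, the incompatibility amounts to an isinstance check (assuming the new `Max` keeps `MaxAndArgmax`'s axis constructor argument and the usual aesara.tensor import paths):

from aesara.tensor.elemwise import CAReduce
from aesara.tensor.math import Max

# A COp-based Max is no longer a CAReduce subclass, so tests and rewrites
# that dispatch on isinstance(op, CAReduce) stop matching it
op = Max(axis=[0])
assert not isinstance(op, CAReduce)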
aesara/tensor/math.py
Outdated
def R_op(self, inputs, eval_points):
    raise ValueError("Argmax is not a differentiable operation")
Suggested change: delete this `R_op` method.
There's already a `grad` implementation—albeit effectively non-functional—so I don't know if it helps to have an `R_op` like this. Also, I don't think a `ValueError` would be the best here. If a gradient is undefined, I believe we should return an `aesara.gradient.grad_undefined`-generated `NullType`. Same with unimplemented gradients: we should use `aesara.gradient.grad_not_implemented`.
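A minimal sketch of that pattern; the class body here is an illustrative stub, not the PR's actual `Argmax`:

from aesara.gradient import grad_undefined
from aesara.link.c.op import COp


class Argmax(COp):
    """Illustrative stub; the real Op also defines make_node, perform, etc."""

    def grad(self, inp, grads):
        (x,) = inp
        # The gradient of argmax is undefined, so return a NullType variable
        # via grad_undefined instead of raising a ValueError. A gradient that
        # is merely unimplemented would use grad_not_implemented the same way.
        return [grad_undefined(self, 0, x, "argmax is not differentiable")]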
…ly.github.com> modify Max Co-authored-by: Brandon T. Willard <[email protected]>
Force-pushed from 2a134f5 to 0b0d854.
Co-authored-by: Brandon T. Willard <[email protected]>
Force-pushed from f4e62f0 to 0d3c415.
Co-authored-by: Brandon T. Willard <[email protected]>
Force-pushed from 0d3c415 to 114bf01.
@brandonwillard I removed …
@brandonwillard, just a soft reminder :)
Looks like one of the commit messages got scrambled.
-class Max(NonZeroCAReduce):
-    nfunc_spec = ("max", 1, 1)
+class Max(COp):
Why aren't we using the `NonZeroCAReduce` base class for this?
@brandonwillard Thank you for the review. To fix this #874 (comment) problem with `Max`, I copied the code from `MaxAndArgmax` and made it a `COp`. Do you think that could cause any problems?
It could interfere with existing code that assumes `Max` is a subclass of `NonZeroCAReduce` (e.g. rewrites, our JAX and/or Numba implementations, etc.). Also, we generally want to make use of existing code; in this case, subclassing `NonZeroCAReduce` could make adding a C implementation much easier than it otherwise would be.
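For comparison, a sketch of what keeping the subclass relationship looks like, based on the pre-change definition quoted above (the import locations and the `scalar_maximum` name are assumptions):

import aesara.scalar.basic as aes
from aesara.tensor.math import NonZeroCAReduce


class Max(NonZeroCAReduce):
    nfunc_spec = ("max", 1, 1)

    def __init__(self, axis):
        # Reuse CAReduce's reduction machinery (including its C
        # implementation) with the scalar maximum as the binary operation
        super().__init__(aes.scalar_maximum, axis)

Existing rewrites and backend dispatches that check `isinstance(op, NonZeroCAReduce)` would then keep working unchanged.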
MaxAndArgmax.debug = 0
Argmax.debug = 0
Max.debug = 0
What do these do?
Hm, it looks like a piece of old code. I think I should remove it.
Closes #765