Skip to content

Commit d5f67fc

Browse files
authored
reverse the order of individual FFTs in rfftn (#2524)
For `numpy.fft.fftn` and `numpy.fft.ifftn`, individual FFTs over `axes` are performed in [reverse order](https://github.com/numpy/numpy/blob/v2.2.0/numpy/fft/_pocketfft.py#L739). Similarly, for `numpy.fft.rfftn`, individual FFTs are performed in [reverse order](https://github.com/numpy/numpy/blob/v2.2.0/numpy/fft/_pocketfft.py#L1382). However, for `numpy.fft.irfftn` individual FFTs are performed in [forward order](https://github.com/numpy/numpy/blob/v2.2.0/numpy/fft/_pocketfft.py#L1600).
1 parent afd5c6d commit d5f67fc

File tree

3 files changed

+37
-14
lines changed

3 files changed

+37
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2424
* Aligned the license expression with `PEP-639` [#2511](https://github.com/IntelPython/dpnp/pull/2511)
2525
* Bumped oneMKL version up to `v0.8` [#2514](https://github.com/IntelPython/dpnp/pull/2514)
2626
* Removed the use of class template argument deduction for alias template to conform to the C++17 standard [#2517](https://github.com/IntelPython/dpnp/pull/2517)
27+
* Changed th order of individual FFTs over `axes` for `dpnp.fft.irfftn` to be in forward order [#2524](https://github.com/IntelPython/dpnp/pull/2524)
2728

2829
### Deprecated
2930

dpnp/fft/dpnp_utils_fft.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,28 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
110110
return dsc, out_strides
111111

112112

113-
def _complex_nd_fft(a, s, norm, out, forward, in_place, c2c, axes, batch_fft):
113+
def _complex_nd_fft(
114+
a,
115+
s,
116+
norm,
117+
out,
118+
forward,
119+
in_place,
120+
c2c,
121+
axes,
122+
batch_fft,
123+
*,
124+
reversed_axes=True,
125+
):
114126
"""Computes complex-to-complex FFT of the input N-D array."""
115127

116128
len_axes = len(axes)
117129
# OneMKL supports up to 3-dimensional FFT on GPU
118130
# repeated axis in OneMKL FFT is not allowed
119131
if len_axes > 3 or len(set(axes)) < len_axes:
120-
axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3)
132+
axes_chunk, shape_chunk = _extract_axes_chunk(
133+
axes, s, chunk_size=3, reversed_axes=reversed_axes
134+
)
121135
for i, (s_chunk, a_chunk) in enumerate(zip(shape_chunk, axes_chunk)):
122136
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
123137
# if out is used in an intermediate step, it will have memory
@@ -291,10 +305,10 @@ def _copy_array(x, complex_input):
291305
return x, copy_flag
292306

293307

294-
def _extract_axes_chunk(a, s, chunk_size=3):
308+
def _extract_axes_chunk(a, s, chunk_size=3, reversed_axes=True):
295309
"""
296310
Classify the first input into a list of lists with each list containing
297-
only unique values in reverse order and its length is at most `chunk_size`.
311+
only unique values and its length is at most `chunk_size`.
298312
The second input is also classified into a list of lists with each list
299313
containing the corresponding values of the first input.
300314
@@ -306,13 +320,14 @@ def _extract_axes_chunk(a, s, chunk_size=3):
306320
The second input.
307321
chunk_size : int
308322
Maximum number of elements in each chunk.
323+
reversed_axes : bool
324+
If True, the output chunks will be in reverse order.
309325
310326
Return
311327
------
312328
out : a tuple of two lists
313329
The first element of output is a list of lists with each list
314-
containing only unique values in revere order and its length is
315-
at most `chunk_size`.
330+
containing only unique values and its length is at most `chunk_size`.
316331
The second element of output is a list of lists with each list
317332
containing the corresponding values of the first input.
318333
@@ -362,7 +377,10 @@ def _extract_axes_chunk(a, s, chunk_size=3):
362377
a_chunks.append(a_current_chunk[::-1])
363378
s_chunks.append(s_current_chunk[::-1])
364379

365-
return a_chunks[::-1], s_chunks[::-1]
380+
if reversed_axes:
381+
return a_chunks[::-1], s_chunks[::-1]
382+
383+
return a_chunks, s_chunks
366384

367385

368386
def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
@@ -531,9 +549,12 @@ def _validate_out_keyword(a, out, s, axes, c2c, c2r, r2c):
531549
expected_shape[axes[-1]] = s[-1] // 2 + 1
532550
elif c2c:
533551
expected_shape[axes[-1]] = s[-1]
534-
for s_i, axis in zip(s[-2::-1], axes[-2::-1]):
535-
expected_shape[axis] = s_i
552+
if r2c or c2c:
553+
for s_i, axis in zip(s[-2::-1], axes[-2::-1]):
554+
expected_shape[axis] = s_i
536555
if c2r:
556+
for s_i, axis in zip(s[:-1], axes[:-1]):
557+
expected_shape[axis] = s_i
537558
expected_shape[axes[-1]] = s[-1]
538559

539560
if out.shape != tuple(expected_shape):
@@ -717,6 +738,7 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
717738
c2c=True,
718739
axes=axes[:-1],
719740
batch_fft=a.ndim != len_axes - 1,
741+
reversed_axes=False,
720742
)
721743
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
722744
if c2r:

dpnp/tests/test_fft.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -401,13 +401,13 @@ def test_repeated_axes(self, axes):
401401
result = dpnp.fft.fftn(ia, axes=axes)
402402
# Intel NumPy ignores repeated axes (mkl_fft-gh-104), handle it one by one
403403
expected = a
404-
for ii in axes:
404+
for ii in axes[::-1]:
405405
expected = numpy.fft.fft(expected, axis=ii)
406406
assert_dtype_allclose(result, expected)
407407

408408
# inverse FFT
409409
result = dpnp.fft.ifftn(result, axes=axes)
410-
for ii in axes:
410+
for ii in axes[::-1]:
411411
expected = numpy.fft.ifft(expected, axis=ii)
412412
assert_dtype_allclose(result, expected)
413413

@@ -893,7 +893,7 @@ def test_repeated_axes(self, axes):
893893

894894
# inverse FFT
895895
result = dpnp.fft.irfftn(result, axes=axes)
896-
for ii in axes[-2::-1]:
896+
for ii in axes[:-1]:
897897
expected = numpy.fft.ifft(expected, axis=ii)
898898
expected = numpy.fft.irfft(expected, axis=axes[-1])
899899
assert_dtype_allclose(result, expected)
@@ -912,7 +912,7 @@ def test_repeated_axes_with_s(self, axes, s):
912912
assert_dtype_allclose(result, expected)
913913

914914
result = dpnp.fft.irfftn(result, s=s, axes=axes)
915-
for jj, ii in zip(s[-2::-1], axes[-2::-1]):
915+
for jj, ii in zip(s[:-1], axes[:-1]):
916916
expected = numpy.fft.ifft(expected, n=jj, axis=ii)
917917
expected = numpy.fft.irfft(expected, n=s[-1], axis=axes[-1])
918918
assert_dtype_allclose(result, expected)
@@ -934,7 +934,7 @@ def test_out(self, axes, s):
934934
assert_dtype_allclose(result, expected)
935935

936936
# inverse FFT
937-
for jj, ii in zip(s[-2::-1], axes[-2::-1]):
937+
for jj, ii in zip(s[:-1], axes[:-1]):
938938
expected = numpy.fft.ifft(expected, n=jj, axis=ii)
939939
expected = numpy.fft.irfft(expected, n=s[-1], axis=axes[-1])
940940
out = dpnp.empty(expected.shape, dtype=numpy.float32)

0 commit comments

Comments
 (0)