Skip to content

Commit 66aad37

Browse files
committed
reverse the order of individual FFTs in rfftn
1 parent 9b73305 commit 66aad37

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

dpnp/fft/dpnp_utils_fft.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,28 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
110110
return dsc, out_strides
111111

112112

113-
def _complex_nd_fft(a, s, norm, out, forward, in_place, c2c, axes, batch_fft):
113+
def _complex_nd_fft(
114+
a,
115+
s,
116+
norm,
117+
out,
118+
forward,
119+
in_place,
120+
c2c,
121+
axes,
122+
batch_fft,
123+
*,
124+
reversed_axes=True,
125+
):
114126
"""Computes complex-to-complex FFT of the input N-D array."""
115127

116128
len_axes = len(axes)
117129
# OneMKL supports up to 3-dimensional FFT on GPU
118130
# repeated axis in OneMKL FFT is not allowed
119131
if len_axes > 3 or len(set(axes)) < len_axes:
120-
axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3)
132+
axes_chunk, shape_chunk = _extract_axes_chunk(
133+
axes, s, chunk_size=3, reversed_axes=reversed_axes
134+
)
121135
for i, (s_chunk, a_chunk) in enumerate(zip(shape_chunk, axes_chunk)):
122136
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
123137
# if out is used in an intermediate step, it will have memory
@@ -291,7 +305,7 @@ def _copy_array(x, complex_input):
291305
return x, copy_flag
292306

293307

294-
def _extract_axes_chunk(a, s, chunk_size=3):
308+
def _extract_axes_chunk(a, s, chunk_size=3, reversed_axes=True):
295309
"""
296310
Classify the first input into a list of lists with each list containing
297311
only unique values in reverse order and its length is at most `chunk_size`.
@@ -362,7 +376,10 @@ def _extract_axes_chunk(a, s, chunk_size=3):
362376
a_chunks.append(a_current_chunk[::-1])
363377
s_chunks.append(s_current_chunk[::-1])
364378

365-
return a_chunks[::-1], s_chunks[::-1]
379+
if reversed_axes:
380+
return a_chunks[::-1], s_chunks[::-1]
381+
382+
return a_chunks, s_chunks
366383

367384

368385
def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
@@ -531,9 +548,12 @@ def _validate_out_keyword(a, out, s, axes, c2c, c2r, r2c):
531548
expected_shape[axes[-1]] = s[-1] // 2 + 1
532549
elif c2c:
533550
expected_shape[axes[-1]] = s[-1]
534-
for s_i, axis in zip(s[-2::-1], axes[-2::-1]):
535-
expected_shape[axis] = s_i
551+
if r2c or c2c:
552+
for s_i, axis in zip(s[-2::-1], axes[-2::-1]):
553+
expected_shape[axis] = s_i
536554
if c2r:
555+
for s_i, axis in zip(s[:-1], axes[:-1]):
556+
expected_shape[axis] = s_i
537557
expected_shape[axes[-1]] = s[-1]
538558

539559
if out.shape != tuple(expected_shape):
@@ -717,6 +737,7 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
717737
c2c=True,
718738
axes=axes[:-1],
719739
batch_fft=a.ndim != len_axes - 1,
740+
reversed_axes=False,
720741
)
721742
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
722743
if c2r:

dpnp/tests/test_fft.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -401,13 +401,13 @@ def test_repeated_axes(self, axes):
401401
result = dpnp.fft.fftn(ia, axes=axes)
402402
# Intel NumPy ignores repeated axes (mkl_fft-gh-104), handle it one by one
403403
expected = a
404-
for ii in axes:
404+
for ii in axes[::-1]:
405405
expected = numpy.fft.fft(expected, axis=ii)
406406
assert_dtype_allclose(result, expected)
407407

408408
# inverse FFT
409409
result = dpnp.fft.ifftn(result, axes=axes)
410-
for ii in axes:
410+
for ii in axes[::-1]:
411411
expected = numpy.fft.ifft(expected, axis=ii)
412412
assert_dtype_allclose(result, expected)
413413

@@ -905,7 +905,7 @@ def test_repeated_axes(self, axes):
905905

906906
# inverse FFT
907907
result = dpnp.fft.irfftn(result, axes=axes)
908-
for ii in axes[-2::-1]:
908+
for ii in axes[:-1]:
909909
expected = numpy.fft.ifft(expected, axis=ii)
910910
expected = numpy.fft.irfft(expected, axis=axes[-1])
911911
assert_dtype_allclose(result, expected)
@@ -924,7 +924,7 @@ def test_repeated_axes_with_s(self, axes, s):
924924
assert_dtype_allclose(result, expected)
925925

926926
result = dpnp.fft.irfftn(result, s=s, axes=axes)
927-
for jj, ii in zip(s[-2::-1], axes[-2::-1]):
927+
for jj, ii in zip(s[:-1], axes[:-1]):
928928
expected = numpy.fft.ifft(expected, n=jj, axis=ii)
929929
expected = numpy.fft.irfft(expected, n=s[-1], axis=axes[-1])
930930
assert_dtype_allclose(result, expected)
@@ -946,7 +946,7 @@ def test_out(self, axes, s):
946946
assert_dtype_allclose(result, expected)
947947

948948
# inverse FFT
949-
for jj, ii in zip(s[-2::-1], axes[-2::-1]):
949+
for jj, ii in zip(s[:-1], axes[:-1]):
950950
expected = numpy.fft.ifft(expected, n=jj, axis=ii)
951951
expected = numpy.fft.irfft(expected, n=s[-1], axis=axes[-1])
952952
out = dpnp.empty(expected.shape, dtype=numpy.float32)

0 commit comments

Comments
 (0)