@@ -110,14 +110,28 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
110
110
return dsc , out_strides
111
111
112
112
113
- def _complex_nd_fft (a , s , norm , out , forward , in_place , c2c , axes , batch_fft ):
113
+ def _complex_nd_fft (
114
+ a ,
115
+ s ,
116
+ norm ,
117
+ out ,
118
+ forward ,
119
+ in_place ,
120
+ c2c ,
121
+ axes ,
122
+ batch_fft ,
123
+ * ,
124
+ reversed_axes = True ,
125
+ ):
114
126
"""Computes complex-to-complex FFT of the input N-D array."""
115
127
116
128
len_axes = len (axes )
117
129
# OneMKL supports up to 3-dimensional FFT on GPU
118
130
# repeated axis in OneMKL FFT is not allowed
119
131
if len_axes > 3 or len (set (axes )) < len_axes :
120
- axes_chunk , shape_chunk = _extract_axes_chunk (axes , s , chunk_size = 3 )
132
+ axes_chunk , shape_chunk = _extract_axes_chunk (
133
+ axes , s , chunk_size = 3 , reversed_axes = reversed_axes
134
+ )
121
135
for i , (s_chunk , a_chunk ) in enumerate (zip (shape_chunk , axes_chunk )):
122
136
a = _truncate_or_pad (a , shape = s_chunk , axes = a_chunk )
123
137
# if out is used in an intermediate step, it will have memory
@@ -291,7 +305,7 @@ def _copy_array(x, complex_input):
291
305
return x , copy_flag
292
306
293
307
294
- def _extract_axes_chunk (a , s , chunk_size = 3 ):
308
+ def _extract_axes_chunk (a , s , chunk_size = 3 , reversed_axes = True ):
295
309
"""
296
310
Classify the first input into a list of lists with each list containing
297
311
only unique values in reverse order and its length is at most `chunk_size`.
@@ -362,7 +376,10 @@ def _extract_axes_chunk(a, s, chunk_size=3):
362
376
a_chunks .append (a_current_chunk [::- 1 ])
363
377
s_chunks .append (s_current_chunk [::- 1 ])
364
378
365
- return a_chunks [::- 1 ], s_chunks [::- 1 ]
379
+ if reversed_axes :
380
+ return a_chunks [::- 1 ], s_chunks [::- 1 ]
381
+
382
+ return a_chunks , s_chunks
366
383
367
384
368
385
def _fft (a , norm , out , forward , in_place , c2c , axes , batch_fft = True ):
@@ -531,9 +548,12 @@ def _validate_out_keyword(a, out, s, axes, c2c, c2r, r2c):
531
548
expected_shape [axes [- 1 ]] = s [- 1 ] // 2 + 1
532
549
elif c2c :
533
550
expected_shape [axes [- 1 ]] = s [- 1 ]
534
- for s_i , axis in zip (s [- 2 ::- 1 ], axes [- 2 ::- 1 ]):
535
- expected_shape [axis ] = s_i
551
+ if r2c or c2c :
552
+ for s_i , axis in zip (s [- 2 ::- 1 ], axes [- 2 ::- 1 ]):
553
+ expected_shape [axis ] = s_i
536
554
if c2r :
555
+ for s_i , axis in zip (s [:- 1 ], axes [:- 1 ]):
556
+ expected_shape [axis ] = s_i
537
557
expected_shape [axes [- 1 ]] = s [- 1 ]
538
558
539
559
if out .shape != tuple (expected_shape ):
@@ -717,6 +737,7 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
717
737
c2c = True ,
718
738
axes = axes [:- 1 ],
719
739
batch_fft = a .ndim != len_axes - 1 ,
740
+ reversed_axes = False ,
720
741
)
721
742
a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
722
743
if c2r :
0 commit comments