|
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import pytest
|
20 |
| -from numpy.testing import assert_array_equal |
| 20 | +from numpy.testing import ( |
| 21 | + assert_, assert_raises_regex, assert_array_equal |
| 22 | +) |
21 | 23 |
|
22 | 24 | import dpctl
|
23 | 25 | import dpctl.tensor as dpt
|
@@ -1067,36 +1069,65 @@ def test_swapaxes_2d():
|
1067 | 1069 |
|
1068 | 1070 | assert_array_equal(exp, dpt.asnumpy(res))
|
1069 | 1071 |
|
1070 |
| - |
1071 |
| -def test_moveaxis_1axis(): |
1072 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1073 |
| - exp = np.moveaxis(x, 0, -1) |
1074 |
| - |
1075 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1076 |
| - res = dpt.moveaxis(y, 0, -1) |
1077 |
| - |
1078 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
1079 |
| - |
1080 |
| - |
1081 |
| -def test_moveaxis_2axes(): |
1082 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1083 |
| - exp = np.moveaxis(x, [0, 1], [-1, -2]) |
1084 |
| - |
1085 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1086 |
| - res = dpt.moveaxis(y, [0, 1], [-1, -2]) |
1087 |
| - |
1088 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
1089 |
| - |
1090 |
| - |
1091 |
| -def test_moveaxis_3axes(): |
1092 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1093 |
| - exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3]) |
1094 |
| - |
1095 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1096 |
| - res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3]) |
1097 |
| - |
1098 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
1099 |
| - |
| 1072 | +def test_moveaxis_move_to_end(): |
| 1073 | + x = dpt.reshape(dpt.arange(5*6*7),(5, 6, 7)) |
| 1074 | + for source, expected in [(0, (6, 7, 5)), |
| 1075 | + (1, (5, 7, 6)), |
| 1076 | + (2, (5, 6, 7)), |
| 1077 | + (-1, (5, 6, 7))]: |
| 1078 | + actual = dpt.moveaxis(x, source, -1).shape |
| 1079 | + assert_(actual, expected) |
| 1080 | + |
| 1081 | +def test_moveaxis_new_position(): |
| 1082 | + x = dpt.reshape(dpt.arange(24),(1, 2, 3, 4)) |
| 1083 | + for source, destination, expected in [ |
| 1084 | + (0, 1, (2, 1, 3, 4)), |
| 1085 | + (1, 2, (1, 3, 2, 4)), |
| 1086 | + (1, -1, (1, 3, 4, 2)), |
| 1087 | + ]: |
| 1088 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1089 | + assert_(actual, expected) |
| 1090 | + |
| 1091 | +def test_moveaxis_preserve_order(): |
| 1092 | + x = dpt.zeros((1, 2, 3, 4)) |
| 1093 | + for source, destination in [ |
| 1094 | + (0, 0), |
| 1095 | + (3, -1), |
| 1096 | + (-1, 3), |
| 1097 | + ([0, -1], [0, -1]), |
| 1098 | + ([2, 0], [2, 0]), |
| 1099 | + ]: |
| 1100 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1101 | + assert_(actual, (1, 2, 3, 4)) |
| 1102 | + |
| 1103 | +def test_moveaxis_move_multiples(): |
| 1104 | + x = dpt.zeros((0, 1, 2, 3)) |
| 1105 | + for source, destination, expected in [ |
| 1106 | + ([0, 1], [2, 3], (2, 3, 0, 1)), |
| 1107 | + ([2, 3], [0, 1], (2, 3, 0, 1)), |
| 1108 | + ([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)), |
| 1109 | + ([3, 0], [1, 0], (0, 3, 1, 2)), |
| 1110 | + ([0, 3], [0, 1], (0, 3, 1, 2)), |
| 1111 | + ]: |
| 1112 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1113 | + assert_(actual, expected) |
| 1114 | + |
| 1115 | +def test_moveaxis_errors(): |
| 1116 | + x = dpt.reshape(dpt.arange(6),(1, 2, 3)) |
| 1117 | + assert_raises_regex(np.AxisError, 'source.*out of bounds', |
| 1118 | + dpt.moveaxis, x, 3, 0) |
| 1119 | + assert_raises_regex(np.AxisError, 'source.*out of bounds', |
| 1120 | + dpt.moveaxis, x, -4, 0) |
| 1121 | + assert_raises_regex(np.AxisError, 'destination.*out of bounds', |
| 1122 | + dpt.moveaxis, x, 0, 5) |
| 1123 | + assert_raises_regex(ValueError, 'repeated axis in `source`', |
| 1124 | + dpt.moveaxis, x, [0, 0], [0, 1]) |
| 1125 | + assert_raises_regex(ValueError, 'repeated axis in `destination`', |
| 1126 | + dpt.moveaxis, x, [0, 1], [1, 1]) |
| 1127 | + assert_raises_regex(ValueError, 'must have the same number', |
| 1128 | + dpt.moveaxis, x, 0, [0, 1]) |
| 1129 | + assert_raises_regex(ValueError, 'must have the same number', |
| 1130 | + dpt.moveaxis, x, [0, 1], [0]) |
1100 | 1131 |
|
1101 | 1132 | def test_unstack_axis0():
|
1102 | 1133 | y = dpt.reshape(dpt.arange(6), (2, 3))
|
|
0 commit comments