Skip to content

Commit 56ea5dd

Browse files
committed
Add some more tests for map_async_iterable
Add back some of the original tests and use similar names in tests.
1 parent 1e77b8c commit 56ea5dd

File tree

2 files changed

+266
-51
lines changed

2 files changed

+266
-51
lines changed

src/graphql/execution/async_iterables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ async def map_async_iterable(
6262
) -> AsyncGenerator[V, None]:
6363
"""Map an AsyncIterable over a callback function.
6464
65-
Given an AsyncIterable and an async callback callable, return an AsyncGenerator
66-
which produces values mapped via calling the callback.
65+
Given an AsyncIterable and an async callback function, return an AsyncGenerator
66+
that produces values mapped via calling the callback function.
6767
If the inner iterator supports an `aclose()` method, it will be called when
6868
the generator finishes or closes.
6969
"""

tests/execution/test_map_async_iterable.py

Lines changed: 264 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,104 +12,319 @@ async def anext(iterator):
1212
return await iterator.__anext__()
1313

1414

15-
async def map_doubles(x: int) -> int:
15+
async def double(x: int) -> int:
16+
"""Test callback that doubles the input value."""
1617
return x + x
1718

1819

20+
async def throw(_x: int) -> int:
21+
"""Test callback that raises a RuntimeError."""
22+
raise RuntimeError("Ouch")
23+
24+
1925
def describe_map_async_iterable():
2026
@mark.asyncio
21-
async def inner_is_closed_when_outer_is_closed():
22-
class Inner:
23-
def __init__(self):
24-
self.closed = False
27+
async def maps_over_async_generator():
28+
async def source():
29+
yield 1
30+
yield 2
31+
yield 3
2532

26-
async def aclose(self):
27-
self.closed = True
33+
doubles = map_async_iterable(source(), double)
34+
35+
assert await anext(doubles) == 2
36+
assert await anext(doubles) == 4
37+
assert await anext(doubles) == 6
38+
with raises(StopAsyncIteration):
39+
assert await anext(doubles)
40+
41+
@mark.asyncio
42+
async def maps_over_async_iterable():
43+
items = [1, 2, 3]
44+
45+
class Iterable:
46+
def __aiter__(self):
47+
return self
48+
49+
async def __anext__(self):
50+
try:
51+
return items.pop(0)
52+
except IndexError:
53+
raise StopAsyncIteration
54+
55+
doubles = map_async_iterable(Iterable(), double)
56+
57+
values = [value async for value in doubles]
58+
59+
assert not items
60+
assert values == [2, 4, 6]
61+
62+
@mark.asyncio
63+
async def compatible_with_async_for():
64+
async def source():
65+
yield 1
66+
yield 2
67+
yield 3
68+
69+
doubles = map_async_iterable(source(), double)
70+
71+
values = [value async for value in doubles]
72+
73+
assert values == [2, 4, 6]
74+
75+
@mark.asyncio
76+
async def allows_returning_early_from_mapped_async_generator():
77+
async def source():
78+
yield 1
79+
yield 2
80+
yield 3 # pragma: no cover
2881

82+
doubles = map_async_iterable(source(), double)
83+
84+
assert await anext(doubles) == 2
85+
assert await anext(doubles) == 4
86+
87+
# Early return
88+
await doubles.aclose()
89+
90+
# Subsequent next calls
91+
with raises(StopAsyncIteration):
92+
await anext(doubles)
93+
with raises(StopAsyncIteration):
94+
await anext(doubles)
95+
96+
@mark.asyncio
97+
async def allows_returning_early_from_mapped_async_iterable():
98+
items = [1, 2, 3]
99+
100+
class Iterable:
101+
def __aiter__(self):
102+
return self
103+
104+
async def __anext__(self):
105+
try:
106+
return items.pop(0)
107+
except IndexError: # pragma: no cover
108+
raise StopAsyncIteration
109+
110+
doubles = map_async_iterable(Iterable(), double)
111+
112+
assert await anext(doubles) == 2
113+
assert await anext(doubles) == 4
114+
115+
# Early return
116+
await doubles.aclose()
117+
118+
# Subsequent next calls
119+
with raises(StopAsyncIteration):
120+
await anext(doubles)
121+
with raises(StopAsyncIteration):
122+
await anext(doubles)
123+
124+
@mark.asyncio
125+
async def allows_throwing_errors_through_async_iterable():
126+
items = [1, 2, 3]
127+
128+
class Iterable:
129+
def __aiter__(self):
130+
return self
131+
132+
async def __anext__(self):
133+
try:
134+
return items.pop(0)
135+
except IndexError: # pragma: no cover
136+
raise StopAsyncIteration
137+
138+
doubles = map_async_iterable(Iterable(), double)
139+
140+
assert await anext(doubles) == 2
141+
assert await anext(doubles) == 4
142+
143+
# Throw error
144+
message = "allows throwing errors when mapping async iterable"
145+
with raises(RuntimeError) as exc_info:
146+
await doubles.athrow(RuntimeError(message))
147+
148+
assert str(exc_info.value) == message
149+
150+
with raises(StopAsyncIteration):
151+
await anext(doubles)
152+
with raises(StopAsyncIteration):
153+
await anext(doubles)
154+
155+
@mark.asyncio
156+
async def allows_throwing_errors_with_values_through_async_iterables():
157+
class Iterable:
158+
def __aiter__(self):
159+
return self
160+
161+
async def __anext__(self):
162+
return 1
163+
164+
one = map_async_iterable(Iterable(), double)
165+
166+
assert await anext(one) == 2
167+
168+
# Throw error with value passed separately
169+
try:
170+
raise RuntimeError("Ouch")
171+
except RuntimeError as error:
172+
with raises(RuntimeError, match="Ouch") as exc_info:
173+
await one.athrow(error.__class__, error)
174+
175+
assert exc_info.value is error
176+
assert exc_info.tb is error.__traceback__
177+
178+
with raises(StopAsyncIteration):
179+
await anext(one)
180+
181+
@mark.asyncio
182+
async def allows_throwing_errors_with_traceback_through_async_iterables():
183+
class Iterable:
29184
def __aiter__(self):
30185
return self
31186

32187
async def __anext__(self):
33188
return 1
34189

35-
inner = Inner()
36-
outer = map_async_iterable(inner, map_doubles)
37-
iterator = outer.__aiter__()
38-
assert await anext(iterator) == 2
39-
assert not inner.closed
40-
await outer.aclose()
41-
assert inner.closed
190+
one = map_async_iterable(Iterable(), double)
191+
192+
assert await anext(one) == 2
193+
194+
# Throw error with traceback passed separately
195+
try:
196+
raise RuntimeError("Ouch")
197+
except RuntimeError as error:
198+
with raises(RuntimeError) as exc_info:
199+
await one.athrow(error.__class__, None, error.__traceback__)
200+
201+
assert exc_info.tb and error.__traceback__
202+
assert exc_info.tb.tb_frame is error.__traceback__.tb_frame
203+
204+
with raises(StopAsyncIteration):
205+
await anext(one)
42206

43207
@mark.asyncio
44-
async def inner_is_closed_on_callback_error():
45-
class Inner:
208+
async def does_not_map_over_thrown_errors():
209+
async def source():
210+
yield 1
211+
raise RuntimeError("Goodbye")
212+
213+
doubles = map_async_iterable(source(), double)
214+
215+
assert await anext(doubles) == 2
216+
217+
with raises(RuntimeError) as exc_info:
218+
await anext(doubles)
219+
220+
assert str(exc_info.value) == "Goodbye"
221+
222+
@mark.asyncio
223+
async def does_not_map_over_externally_thrown_errors():
224+
async def source():
225+
yield 1
226+
227+
doubles = map_async_iterable(source(), double)
228+
229+
assert await anext(doubles) == 2
230+
231+
with raises(RuntimeError) as exc_info:
232+
await doubles.athrow(RuntimeError("Goodbye"))
233+
234+
assert str(exc_info.value) == "Goodbye"
235+
236+
@mark.asyncio
237+
async def iterable_is_closed_when_mapped_iterable_is_closed():
238+
class Iterable:
46239
def __init__(self):
47240
self.closed = False
48241

242+
def __aiter__(self):
243+
return self
244+
245+
async def __anext__(self):
246+
return 1
247+
49248
async def aclose(self):
50249
self.closed = True
51250

251+
iterable = Iterable()
252+
doubles = map_async_iterable(iterable, double)
253+
assert await anext(doubles) == 2
254+
assert not iterable.closed
255+
await doubles.aclose()
256+
assert iterable.closed
257+
with raises(StopAsyncIteration):
258+
await anext(doubles)
259+
260+
@mark.asyncio
261+
async def iterable_is_closed_on_callback_error():
262+
class Iterable:
263+
def __init__(self):
264+
self.closed = False
265+
52266
def __aiter__(self):
53267
return self
54268

55269
async def __anext__(self):
56270
return 1
57271

58-
async def callback(v):
59-
raise RuntimeError()
272+
async def aclose(self):
273+
self.closed = True
60274

61-
inner = Inner()
62-
outer = map_async_iterable(inner, callback)
63-
with raises(RuntimeError):
64-
await anext(outer)
65-
assert inner.closed
275+
iterable = Iterable()
276+
doubles = map_async_iterable(iterable, throw)
277+
with raises(RuntimeError, match="Ouch"):
278+
await anext(doubles)
279+
assert iterable.closed
280+
with raises(StopAsyncIteration):
281+
await anext(doubles)
66282

67283
@mark.asyncio
68-
async def test_inner_exits_on_callback_error():
69-
inner_exit = False
284+
async def iterable_exits_on_callback_error():
285+
exited = False
70286

71-
async def inner():
72-
nonlocal inner_exit
287+
async def iterable():
288+
nonlocal exited
73289
try:
74290
while True:
75291
yield 1
76292
except GeneratorExit:
77-
inner_exit = True
293+
exited = True
78294

79-
async def callback(v):
80-
raise RuntimeError
81-
82-
outer = map_async_iterable(inner(), callback)
83-
with raises(RuntimeError):
84-
await anext(outer)
85-
assert inner_exit
295+
doubles = map_async_iterable(iterable(), throw)
296+
with raises(RuntimeError, match="Ouch"):
297+
await anext(doubles)
298+
assert exited
299+
with raises(StopAsyncIteration):
300+
await anext(doubles)
86301

87302
@mark.asyncio
88-
async def inner_has_no_close_method_when_outer_is_closed():
89-
class Inner:
303+
async def mapped_iterable_is_closed_when_iterable_cannot_be_closed():
304+
class Iterable:
90305
def __aiter__(self):
91306
return self
92307

93308
async def __anext__(self):
94309
return 1
95310

96-
outer = map_async_iterable(Inner(), map_doubles)
97-
iterator = outer.__aiter__()
98-
assert await anext(iterator) == 2
99-
await outer.aclose()
311+
doubles = map_async_iterable(Iterable(), double)
312+
assert await anext(doubles) == 2
313+
await doubles.aclose()
314+
with raises(StopAsyncIteration):
315+
await anext(doubles)
100316

101317
@mark.asyncio
102-
async def inner_has_no_close_method_on_callback_error():
103-
class Inner:
318+
async def ignores_that_iterable_cannot_be_closed_on_callback_error():
319+
class Iterable:
104320
def __aiter__(self):
105321
return self
106322

107323
async def __anext__(self):
108324
return 1
109325

110-
async def callback(v):
111-
raise RuntimeError()
112-
113-
outer = map_async_iterable(Inner(), callback)
114-
with raises(RuntimeError):
115-
await anext(outer)
326+
doubles = map_async_iterable(Iterable(), throw)
327+
with raises(RuntimeError, match="Ouch"):
328+
await anext(doubles)
329+
with raises(StopAsyncIteration):
330+
await anext(doubles)

0 commit comments

Comments
 (0)