@@ -12,104 +12,319 @@ async def anext(iterator):
12
12
return await iterator .__anext__ ()
13
13
14
14
15
- async def map_doubles (x : int ) -> int :
15
+ async def double (x : int ) -> int :
16
+ """Test callback that doubles the input value."""
16
17
return x + x
17
18
18
19
20
+ async def throw (_x : int ) -> int :
21
+ """Test callback that raises a RuntimeError."""
22
+ raise RuntimeError ("Ouch" )
23
+
24
+
19
25
def describe_map_async_iterable ():
20
26
@mark .asyncio
21
- async def inner_is_closed_when_outer_is_closed ():
22
- class Inner :
23
- def __init__ (self ):
24
- self .closed = False
27
+ async def maps_over_async_generator ():
28
+ async def source ():
29
+ yield 1
30
+ yield 2
31
+ yield 3
25
32
26
- async def aclose (self ):
27
- self .closed = True
33
+ doubles = map_async_iterable (source (), double )
34
+
35
+ assert await anext (doubles ) == 2
36
+ assert await anext (doubles ) == 4
37
+ assert await anext (doubles ) == 6
38
+ with raises (StopAsyncIteration ):
39
+ assert await anext (doubles )
40
+
41
+ @mark .asyncio
42
+ async def maps_over_async_iterable ():
43
+ items = [1 , 2 , 3 ]
44
+
45
+ class Iterable :
46
+ def __aiter__ (self ):
47
+ return self
48
+
49
+ async def __anext__ (self ):
50
+ try :
51
+ return items .pop (0 )
52
+ except IndexError :
53
+ raise StopAsyncIteration
54
+
55
+ doubles = map_async_iterable (Iterable (), double )
56
+
57
+ values = [value async for value in doubles ]
58
+
59
+ assert not items
60
+ assert values == [2 , 4 , 6 ]
61
+
62
+ @mark .asyncio
63
+ async def compatible_with_async_for ():
64
+ async def source ():
65
+ yield 1
66
+ yield 2
67
+ yield 3
68
+
69
+ doubles = map_async_iterable (source (), double )
70
+
71
+ values = [value async for value in doubles ]
72
+
73
+ assert values == [2 , 4 , 6 ]
74
+
75
+ @mark .asyncio
76
+ async def allows_returning_early_from_mapped_async_generator ():
77
+ async def source ():
78
+ yield 1
79
+ yield 2
80
+ yield 3 # pragma: no cover
28
81
82
+ doubles = map_async_iterable (source (), double )
83
+
84
+ assert await anext (doubles ) == 2
85
+ assert await anext (doubles ) == 4
86
+
87
+ # Early return
88
+ await doubles .aclose ()
89
+
90
+ # Subsequent next calls
91
+ with raises (StopAsyncIteration ):
92
+ await anext (doubles )
93
+ with raises (StopAsyncIteration ):
94
+ await anext (doubles )
95
+
96
+ @mark .asyncio
97
+ async def allows_returning_early_from_mapped_async_iterable ():
98
+ items = [1 , 2 , 3 ]
99
+
100
+ class Iterable :
101
+ def __aiter__ (self ):
102
+ return self
103
+
104
+ async def __anext__ (self ):
105
+ try :
106
+ return items .pop (0 )
107
+ except IndexError : # pragma: no cover
108
+ raise StopAsyncIteration
109
+
110
+ doubles = map_async_iterable (Iterable (), double )
111
+
112
+ assert await anext (doubles ) == 2
113
+ assert await anext (doubles ) == 4
114
+
115
+ # Early return
116
+ await doubles .aclose ()
117
+
118
+ # Subsequent next calls
119
+ with raises (StopAsyncIteration ):
120
+ await anext (doubles )
121
+ with raises (StopAsyncIteration ):
122
+ await anext (doubles )
123
+
124
+ @mark .asyncio
125
+ async def allows_throwing_errors_through_async_iterable ():
126
+ items = [1 , 2 , 3 ]
127
+
128
+ class Iterable :
129
+ def __aiter__ (self ):
130
+ return self
131
+
132
+ async def __anext__ (self ):
133
+ try :
134
+ return items .pop (0 )
135
+ except IndexError : # pragma: no cover
136
+ raise StopAsyncIteration
137
+
138
+ doubles = map_async_iterable (Iterable (), double )
139
+
140
+ assert await anext (doubles ) == 2
141
+ assert await anext (doubles ) == 4
142
+
143
+ # Throw error
144
+ message = "allows throwing errors when mapping async iterable"
145
+ with raises (RuntimeError ) as exc_info :
146
+ await doubles .athrow (RuntimeError (message ))
147
+
148
+ assert str (exc_info .value ) == message
149
+
150
+ with raises (StopAsyncIteration ):
151
+ await anext (doubles )
152
+ with raises (StopAsyncIteration ):
153
+ await anext (doubles )
154
+
155
+ @mark .asyncio
156
+ async def allows_throwing_errors_with_values_through_async_iterables ():
157
+ class Iterable :
158
+ def __aiter__ (self ):
159
+ return self
160
+
161
+ async def __anext__ (self ):
162
+ return 1
163
+
164
+ one = map_async_iterable (Iterable (), double )
165
+
166
+ assert await anext (one ) == 2
167
+
168
+ # Throw error with value passed separately
169
+ try :
170
+ raise RuntimeError ("Ouch" )
171
+ except RuntimeError as error :
172
+ with raises (RuntimeError , match = "Ouch" ) as exc_info :
173
+ await one .athrow (error .__class__ , error )
174
+
175
+ assert exc_info .value is error
176
+ assert exc_info .tb is error .__traceback__
177
+
178
+ with raises (StopAsyncIteration ):
179
+ await anext (one )
180
+
181
+ @mark .asyncio
182
+ async def allows_throwing_errors_with_traceback_through_async_iterables ():
183
+ class Iterable :
29
184
def __aiter__ (self ):
30
185
return self
31
186
32
187
async def __anext__ (self ):
33
188
return 1
34
189
35
- inner = Inner ()
36
- outer = map_async_iterable (inner , map_doubles )
37
- iterator = outer .__aiter__ ()
38
- assert await anext (iterator ) == 2
39
- assert not inner .closed
40
- await outer .aclose ()
41
- assert inner .closed
190
+ one = map_async_iterable (Iterable (), double )
191
+
192
+ assert await anext (one ) == 2
193
+
194
+ # Throw error with traceback passed separately
195
+ try :
196
+ raise RuntimeError ("Ouch" )
197
+ except RuntimeError as error :
198
+ with raises (RuntimeError ) as exc_info :
199
+ await one .athrow (error .__class__ , None , error .__traceback__ )
200
+
201
+ assert exc_info .tb and error .__traceback__
202
+ assert exc_info .tb .tb_frame is error .__traceback__ .tb_frame
203
+
204
+ with raises (StopAsyncIteration ):
205
+ await anext (one )
42
206
43
207
@mark .asyncio
44
- async def inner_is_closed_on_callback_error ():
45
- class Inner :
208
+ async def does_not_map_over_thrown_errors ():
209
+ async def source ():
210
+ yield 1
211
+ raise RuntimeError ("Goodbye" )
212
+
213
+ doubles = map_async_iterable (source (), double )
214
+
215
+ assert await anext (doubles ) == 2
216
+
217
+ with raises (RuntimeError ) as exc_info :
218
+ await anext (doubles )
219
+
220
+ assert str (exc_info .value ) == "Goodbye"
221
+
222
+ @mark .asyncio
223
+ async def does_not_map_over_externally_thrown_errors ():
224
+ async def source ():
225
+ yield 1
226
+
227
+ doubles = map_async_iterable (source (), double )
228
+
229
+ assert await anext (doubles ) == 2
230
+
231
+ with raises (RuntimeError ) as exc_info :
232
+ await doubles .athrow (RuntimeError ("Goodbye" ))
233
+
234
+ assert str (exc_info .value ) == "Goodbye"
235
+
236
+ @mark .asyncio
237
+ async def iterable_is_closed_when_mapped_iterable_is_closed ():
238
+ class Iterable :
46
239
def __init__ (self ):
47
240
self .closed = False
48
241
242
+ def __aiter__ (self ):
243
+ return self
244
+
245
+ async def __anext__ (self ):
246
+ return 1
247
+
49
248
async def aclose (self ):
50
249
self .closed = True
51
250
251
+ iterable = Iterable ()
252
+ doubles = map_async_iterable (iterable , double )
253
+ assert await anext (doubles ) == 2
254
+ assert not iterable .closed
255
+ await doubles .aclose ()
256
+ assert iterable .closed
257
+ with raises (StopAsyncIteration ):
258
+ await anext (doubles )
259
+
260
+ @mark .asyncio
261
+ async def iterable_is_closed_on_callback_error ():
262
+ class Iterable :
263
+ def __init__ (self ):
264
+ self .closed = False
265
+
52
266
def __aiter__ (self ):
53
267
return self
54
268
55
269
async def __anext__ (self ):
56
270
return 1
57
271
58
- async def callback ( v ):
59
- raise RuntimeError ()
272
+ async def aclose ( self ):
273
+ self . closed = True
60
274
61
- inner = Inner ()
62
- outer = map_async_iterable (inner , callback )
63
- with raises (RuntimeError ):
64
- await anext (outer )
65
- assert inner .closed
275
+ iterable = Iterable ()
276
+ doubles = map_async_iterable (iterable , throw )
277
+ with raises (RuntimeError , match = "Ouch" ):
278
+ await anext (doubles )
279
+ assert iterable .closed
280
+ with raises (StopAsyncIteration ):
281
+ await anext (doubles )
66
282
67
283
@mark .asyncio
68
- async def test_inner_exits_on_callback_error ():
69
- inner_exit = False
284
+ async def iterable_exits_on_callback_error ():
285
+ exited = False
70
286
71
- async def inner ():
72
- nonlocal inner_exit
287
+ async def iterable ():
288
+ nonlocal exited
73
289
try :
74
290
while True :
75
291
yield 1
76
292
except GeneratorExit :
77
- inner_exit = True
293
+ exited = True
78
294
79
- async def callback (v ):
80
- raise RuntimeError
81
-
82
- outer = map_async_iterable (inner (), callback )
83
- with raises (RuntimeError ):
84
- await anext (outer )
85
- assert inner_exit
295
+ doubles = map_async_iterable (iterable (), throw )
296
+ with raises (RuntimeError , match = "Ouch" ):
297
+ await anext (doubles )
298
+ assert exited
299
+ with raises (StopAsyncIteration ):
300
+ await anext (doubles )
86
301
87
302
@mark .asyncio
88
- async def inner_has_no_close_method_when_outer_is_closed ():
89
- class Inner :
303
+ async def mapped_iterable_is_closed_when_iterable_cannot_be_closed ():
304
+ class Iterable :
90
305
def __aiter__ (self ):
91
306
return self
92
307
93
308
async def __anext__ (self ):
94
309
return 1
95
310
96
- outer = map_async_iterable (Inner (), map_doubles )
97
- iterator = outer .__aiter__ ()
98
- assert await anext (iterator ) == 2
99
- await outer .aclose ()
311
+ doubles = map_async_iterable (Iterable (), double )
312
+ assert await anext (doubles ) == 2
313
+ await doubles .aclose ()
314
+ with raises (StopAsyncIteration ):
315
+ await anext (doubles )
100
316
101
317
@mark .asyncio
102
- async def inner_has_no_close_method_on_callback_error ():
103
- class Inner :
318
+ async def ignores_that_iterable_cannot_be_closed_on_callback_error ():
319
+ class Iterable :
104
320
def __aiter__ (self ):
105
321
return self
106
322
107
323
async def __anext__ (self ):
108
324
return 1
109
325
110
- async def callback (v ):
111
- raise RuntimeError ()
112
-
113
- outer = map_async_iterable (Inner (), callback )
114
- with raises (RuntimeError ):
115
- await anext (outer )
326
+ doubles = map_async_iterable (Iterable (), throw )
327
+ with raises (RuntimeError , match = "Ouch" ):
328
+ await anext (doubles )
329
+ with raises (StopAsyncIteration ):
330
+ await anext (doubles )
0 commit comments