gh-119793: Add optional length-checking to map() #120471

Merged
merged 45 commits into from
Nov 4, 2024
06fdd71
Add `strict` argument
nineteendo Jun 13, 2024
a7966f8
📜🤖 Added by blurb_it.
blurb-it[bot] Jun 13, 2024
89e25d9
Add tests
nineteendo Jun 13, 2024
a48ebcf
Update commented out code
nineteendo Jun 13, 2024
e4ee2fb
Update docs
nineteendo Jun 13, 2024
bdd0fcb
Fix `__reduce__()`
nineteendo Jun 13, 2024
91a8450
Revert `map_vectorcall()`
nineteendo Jun 14, 2024
2c7d524
Reduce diff
nineteendo Jun 14, 2024
5ad7eb2
Never set
nineteendo Jun 14, 2024
6d8584b
Reduce diff 2
nineteendo Jun 14, 2024
e436a2b
Fixed undefined variable
nineteendo Jun 14, 2024
f9a2a49
Apply suggestions from code review
nineteendo Jun 14, 2024
800ac10
Merge branch 'main' into strict-map
nineteendo Jun 19, 2024
301da5d
Update 2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst
nineteendo Jun 19, 2024
fb1c379
Update whatsnew
nineteendo Jun 19, 2024
255e3a2
Fix pr number
nineteendo Jun 19, 2024
d1c1769
Fix signature
nineteendo Jun 19, 2024
6d20976
Add comment
nineteendo Jun 19, 2024
c338c8f
Add comment 2
nineteendo Jun 19, 2024
99d9c75
Fix typo
nineteendo Jun 19, 2024
63522b3
Update Doc/library/functions.rst
nineteendo Jun 19, 2024
fcb438f
Match news entry of `zip()`
nineteendo Jun 19, 2024
6e97e31
Fix trailing whitespace
nineteendo Jun 19, 2024
54231e0
Update 2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst
nineteendo Jun 24, 2024
b5e8ac4
Update 2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst
nineteendo Jun 24, 2024
53f3f58
Remove pep reference
nineteendo Sep 4, 2024
a2fe008
Merge branch 'main' into strict-map
nineteendo Sep 28, 2024
9a2c0fd
Make tests more maintainable
nineteendo Oct 3, 2024
f75e511
Update Python/bltinmodule.c
nineteendo Oct 10, 2024
92134e7
Update Python/bltinmodule.c
nineteendo Oct 10, 2024
e74c3ff
Use correct variable
nineteendo Oct 10, 2024
bf7f350
Update Python/bltinmodule.c
nineteendo Oct 31, 2024
9cc41f2
Rename 2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst to 2024-06-13-1…
nineteendo Oct 31, 2024
f686101
Fix shadowed variable
nineteendo Oct 31, 2024
e524e17
Apply suggestions from code review
nineteendo Oct 31, 2024
91bdf2f
Rename 2024-06-13-19-12-49.gh-issue-119793.FDVCDk.rst to 2024-06-13-1…
nineteendo Oct 31, 2024
39c34f1
Use same message for news.d
nineteendo Oct 31, 2024
ea65b2c
Update Doc/whatsnew/3.14.rst
nineteendo Oct 31, 2024
45c2d0d
Apply suggestions from code review
nineteendo Oct 31, 2024
ad6c2a5
Sync with what's new
nineteendo Oct 31, 2024
9ebc398
Accept truthy value in setstate
nineteendo Oct 31, 2024
d829a17
Fix removed line
nineteendo Oct 31, 2024
d1d858f
Use stdbool
nineteendo Oct 31, 2024
f23caa5
Revert "Use stdbool"
nineteendo Oct 31, 2024
7e32e2c
Apply suggestions from code review
nineteendo Nov 1, 2024
11 changes: 8 additions & 3 deletions Doc/library/functions.rst
@@ -1192,14 +1192,19 @@ are always available. They are listed here in alphabetical order.
unchanged from previous versions.


.. function:: map(function, iterable, *iterables)
.. function:: map(function, iterable, /, *iterables, strict=False)

Return an iterator that applies *function* to every item of *iterable*,
yielding the results. If additional *iterables* arguments are passed,
*function* must take that many arguments and is applied to the items from all
iterables in parallel. With multiple iterables, the iterator stops when the
shortest iterable is exhausted. For cases where the function inputs are
already arranged into argument tuples, see :func:`itertools.starmap`\.
shortest iterable is exhausted. If *strict* is ``True`` and one of the
iterables is exhausted before the others, a :exc:`ValueError` is raised. For
cases where the function inputs are already arranged into argument tuples,
see :func:`itertools.starmap`.

.. versionchanged:: 3.14
Added the *strict* parameter.


.. function:: max(iterable, *, key=None)
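A short usage sketch of the documented behaviour may help here. It assumes an interpreter that includes this change (Python 3.14 or later). `map_supports_strict` and `checked_products` are hypothetical helper names, not part of the patch, and the fallback branch exists only so the snippet also runs on older versions.

```python
from operator import mul


def map_supports_strict():
    """Return True if the running interpreter's map() accepts strict=."""
    try:
        map(abs, (), strict=True)
    except TypeError:
        # Older interpreters reject every keyword argument to map().
        return False
    return True


def checked_products(xs, ys):
    """Multiply pairwise; raise ValueError if the inputs differ in length."""
    if map_supports_strict():
        return list(map(mul, xs, ys, strict=True))
    # Fallback (assumption, not from the patch): materialise and compare lengths.
    xs, ys = list(xs), list(ys)
    if len(xs) != len(ys):
        raise ValueError("inputs differ in length")
    return list(map(mul, xs, ys))
```

Without `strict=True`, the mismatched call would silently stop at the shorter input instead of raising.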
4 changes: 4 additions & 0 deletions Doc/whatsnew/3.14.rst
@@ -175,6 +175,10 @@ Improved Error Messages
Other Language Changes
======================

* The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
(Contributed by Wannes Boeykens in :gh:`119793`.)

* Incorrect usage of :keyword:`await` and asynchronous comprehensions
is now detected even if the code is optimized away by the :option:`-O`
command line option. For example, ``python -O -c 'assert await 1'``
105 changes: 105 additions & 0 deletions Lib/test/test_builtin.py
@@ -148,6 +148,9 @@ def filter_char(arg):
def map_char(arg):
return chr(ord(arg)+1)

def pack(*args):
return args

class BuiltinTest(unittest.TestCase):
# Helper to check picklability
def check_iter_pickle(self, it, seq, proto):
@@ -1269,6 +1272,108 @@ def test_map_pickle(self):
m2 = map(map_char, "Is this the real life?")
self.check_iter_pickle(m1, list(m2), proto)

# strict map tests based on strict zip tests

def test_map_pickle_strict(self):
a = (1, 2, 3)
b = (4, 5, 6)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
self.check_iter_pickle(m1, t, proto)

def test_map_pickle_strict_fail(self):
a = (1, 2, 3)
b = (4, 5, 6, 7)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
m2 = pickle.loads(pickle.dumps(m1, proto))
self.assertEqual(self.iter_error(m1, ValueError), t)
self.assertEqual(self.iter_error(m2, ValueError), t)

def test_map_strict(self):
self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
((1, 'a'), (2, 'b'), (3, 'c')))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2, 3, 4), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), (1, 2), 'abc', strict=True))

def test_map_strict_iterators(self):
x = iter(range(5))
y = [0]
z = iter(range(5))
self.assertRaises(ValueError, list,
(map(pack, x, y, z, strict=True)))
self.assertEqual(next(x), 2)
self.assertEqual(next(z), 1)

def test_map_strict_error_handling(self):

class Error(Exception):
pass

class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise Error
return self.size

l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), Error)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), Error)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), Error)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), Error)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])

def test_map_strict_error_handling_stopiteration(self):

class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise StopIteration
return self.size

l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), ValueError)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), ValueError)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])

def test_max(self):
self.assertEqual(max('123123'), '3')
self.assertEqual(max(1, 2, 3), 3)
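The strict tests above call `self.iter_error`, which is not part of this diff. Assuming it behaves like the helper the strict `zip()` tests rely on, it collects items until the expected exception is raised and returns what was produced so far. A standalone sketch of that idea (the free function form and the `AssertionError` are choices made here, not taken from the test suite):

```python
def iter_error(iterable, error):
    """Collect items from `iterable` until it raises `error`; return the items."""
    items = []
    try:
        for item in iterable:
            items.append(item)
    except error:
        return items
    # Reaching this point means the iterable finished without raising.
    raise AssertionError(f"{error.__name__} was not raised")


def two_then_fail():
    """Yield two items, then fail, to exercise iter_error()."""
    yield 1
    yield 2
    raise ValueError("boom")
```

This is why each test can assert both the exception type and the exact results yielded before the failure.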
4 changes: 2 additions & 2 deletions Lib/test/test_itertools.py
@@ -2421,10 +2421,10 @@ class subclass(cls):
subclass(*args, newarg=3)

for cls, args, result in testcases:
# Constructors of repeat, zip, compress accept keyword arguments.
# Constructors of repeat, zip, map, compress accept keyword arguments.
# Their subclasses need overriding __new__ to support new
# keyword arguments.
if cls in [repeat, zip, compress]:
if cls in [repeat, zip, map, compress]:
continue
with self.subTest(cls):
class subclass_with_init(cls):
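The comment in this hunk says that subclasses must override `__new__` to accept new keyword arguments. A minimal sketch of that pattern follows, using `itertools.repeat` so it runs on interpreters without this change. `LabelledRepeat` and `label` are hypothetical names chosen for illustration.

```python
from itertools import repeat


class LabelledRepeat(repeat):
    """Subclass that accepts one extra keyword argument."""

    def __new__(cls, *args, label=None):
        # Strip the new keyword before delegating to the C constructor,
        # which would otherwise reject it.
        self = super().__new__(cls, *args)
        self.label = label
        return self
```

With this change, `map` joins the types whose constructors accept keyword arguments themselves, so the test skips it alongside `repeat`, `zip`, and `compress`.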
@@ -0,0 +1,3 @@
The :func:`map` built-in now has an optional keyword-only *strict* flag
like :func:`zip` to check that all the iterables are of equal length.
Patch by Wannes Boeykens.
100 changes: 88 additions & 12 deletions Python/bltinmodule.c
@@ -1303,6 +1303,7 @@ typedef struct {
PyObject_HEAD
PyObject *iters;
PyObject *func;
int strict;
} mapobject;

static PyObject *
@@ -1311,10 +1312,21 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject *it, *iters, *func;
mapobject *lz;
Py_ssize_t numargs, i;
int strict = 0;

if ((type == &PyMap_Type || type->tp_init == PyMap_Type.tp_init) &&
!_PyArg_NoKeywords("map", kwds))
return NULL;
if (kwds) {
PyObject *empty = PyTuple_New(0);
if (empty == NULL) {
return NULL;
}
static char *kwlist[] = {"strict", NULL};
int parsed = PyArg_ParseTupleAndKeywords(
empty, kwds, "|$p:map", kwlist, &strict);
Py_DECREF(empty);
if (!parsed) {
return NULL;
}
}

numargs = PyTuple_Size(args);
if (numargs < 2) {
@@ -1346,6 +1358,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
lz->iters = iters;
func = PyTuple_GET_ITEM(args, 0);
lz->func = Py_NewRef(func);
lz->strict = strict;

return (PyObject *)lz;
}
@@ -1355,11 +1368,14 @@ map_vectorcall(PyObject *type, PyObject * const*args,
size_t nargsf, PyObject *kwnames)
{
PyTypeObject *tp = _PyType_CAST(type);
if (tp == &PyMap_Type && !_PyArg_NoKwnames("map", kwnames)) {
return NULL;
}

Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) {
// Fallback to map_new()
PyThreadState *tstate = _PyThreadState_GET();
return _PyObject_MakeTpCall(tstate, type, args, nargs, kwnames);
}

if (nargs < 2) {
PyErr_SetString(PyExc_TypeError,
"map() must have at least two arguments.");
@@ -1387,6 +1403,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
}
lz->iters = iters;
lz->func = Py_NewRef(args[0]);
lz->strict = 0;

return (PyObject *)lz;
}
@@ -1411,6 +1428,7 @@ map_traverse(mapobject *lz, visitproc visit, void *arg)
static PyObject *
map_next(mapobject *lz)
{
Py_ssize_t i;
PyObject *small_stack[_PY_FASTCALL_SMALL_STACK];
PyObject **stack;
PyObject *result = NULL;
@@ -1429,10 +1447,13 @@ map_next(mapobject *lz)
}

Py_ssize_t nargs = 0;
for (Py_ssize_t i=0; i < niters; i++) {
for (i=0; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = Py_TYPE(it)->tp_iternext(it);
if (val == NULL) {
if (lz->strict) {
goto check;
}
goto exit;
}
stack[i] = val;
@@ -1442,13 +1463,50 @@ map_next(mapobject *lz)
result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);

exit:
for (Py_ssize_t i=0; i < nargs; i++) {
for (i=0; i < nargs; i++) {
Py_DECREF(stack[i]);
}
if (stack != small_stack) {
PyMem_Free(stack);
}
return result;
check:
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
if (i) {
// ValueError: map() argument 2 is shorter than argument 1
// ValueError: map() argument 3 is shorter than arguments 1-2
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is shorter than argument%s%d",
i + 1, plural, i);
}
for (i = 1; i < niters; i++) {
PyObject *it = PyTuple_GET_ITEM(lz->iters, i);
PyObject *val = (*Py_TYPE(it)->tp_iternext)(it);
if (val) {
Py_DECREF(val);
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"map() argument %d is longer than argument%s%d",
i + 1, plural, i);
}
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
// Argument i is exhausted. So far so good...
}
// All arguments are exhausted. Success!
goto exit;
}
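For readers who prefer Python to C, the control flow of the `check:` label can be approximated as a generator. This is a hedged sketch, not the implementation: `strict_map` is a hypothetical name, and C-level details such as reference counting and the argument stack are not modelled. The error messages follow the format strings in the hunk above.

```python
def strict_map(function, iterable, /, *iterables):
    """Pure-Python approximation of map(function, ..., strict=True)."""
    iterators = [iter(it) for it in (iterable, *iterables)]
    sentinel = object()
    while True:
        args = []
        for i, it in enumerate(iterators):
            value = next(it, sentinel)
            if value is sentinel:
                break  # iterator i is exhausted; go check the others
            args.append(value)
        else:
            yield function(*args)
            continue
        if i:
            # A later argument ran out while earlier ones still had items.
            plural = " " if i == 1 else "s 1-"
            raise ValueError(
                f"map() argument {i + 1} is shorter than argument{plural}{i}")
        # The first argument ran out: every other one must be exhausted too.
        for j, it in enumerate(iterators[1:], start=1):
            if next(it, sentinel) is not sentinel:
                plural = " " if j == 1 else "s 1-"
                raise ValueError(
                    f"map() argument {j + 1} is longer than argument{plural}{j}")
        return
```

As in the C code, an exception other than `StopIteration` raised by an input iterator propagates unchanged, and the extra `next()` calls made while checking consume one item from a longer iterable.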

static PyObject *
@@ -1465,21 +1523,41 @@ map_reduce(mapobject *lz, PyObject *Py_UNUSED(ignored))
PyTuple_SET_ITEM(args, i+1, Py_NewRef(it));
}

if (lz->strict) {
return Py_BuildValue("ONO", Py_TYPE(lz), args, Py_True);
}
return Py_BuildValue("ON", Py_TYPE(lz), args);
}

PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");

static PyObject *
map_setstate(mapobject *lz, PyObject *state)
{
int strict = PyObject_IsTrue(state);
if (strict < 0) {
return NULL;
}
lz->strict = strict;
Py_RETURN_NONE;
}

static PyMethodDef map_methods[] = {
{"__reduce__", _PyCFunction_CAST(map_reduce), METH_NOARGS, reduce_doc},
{"__setstate__", _PyCFunction_CAST(map_setstate), METH_O, setstate_doc},
{NULL, NULL} /* sentinel */
};


PyDoc_STRVAR(map_doc,
"map(function, iterable, /, *iterables)\n\
"map(function, iterable, /, *iterables, strict=False)\n\
--\n\
\n\
Make an iterator that computes the function using arguments from\n\
each of the iterables. Stops when the shortest iterable is exhausted.");
each of the iterables. Stops when the shortest iterable is exhausted.\n\
\n\
If strict is true and one of the arguments is exhausted before the others,\n\
raise a ValueError.");

PyTypeObject PyMap_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
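The `__reduce__`/`__setstate__` pairing in this hunk can be illustrated in pure Python. The sketch below uses a hypothetical stand-in class, `FlaggedMap`, and exercises it with `copy.copy`, which goes through the same reduce protocol that `pickle` uses. It models only the state handling, not iteration.

```python
import copy


class FlaggedMap:
    """Hypothetical stand-in for mapobject, reduced to its pickling logic."""

    def __init__(self, func, *iterables):
        self.func = func
        self.iterables = iterables
        self.strict = False  # the constructor alone never enables strict

    def __reduce__(self):
        args = (self.func, *self.iterables)
        if self.strict:
            # Three-item form: the third item is handed to __setstate__().
            return (type(self), args, True)
        # Two-item form: no state, so __setstate__() is never called.
        return (type(self), args)

    def __setstate__(self, state):
        # Any truthy state enables strict mode, like PyObject_IsTrue() in C.
        self.strict = bool(state)
```

The design keeps non-strict reductions identical to the previous two-item form, and only strict iterators carry the extra state item.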
@@ -3060,8 +3138,6 @@ zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
}

PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");

static PyObject *
zip_setstate(zipobject *lz, PyObject *state)
{