Tutorial: Comprehensions, Iterators, and Iterables

Authors: Florent Hivert <florent.hivert@univ-rouen.fr> and Nicolas M. Thiéry <nthiery at users.sf.net>

List comprehensions

List comprehensions are a very handy way to construct lists in Python. You can use either of the following idioms:

[ <expr> for <name> in <iterable> ]
[ <expr> for <name> in <iterable> if <condition> ]

For example, here are some lists of squares:

sage: [ i^2 for i in [1, 3, 7] ]
[1, 9, 49]
sage: [ i^2 for i in range(1,10) ]
[1, 4, 9, 16, 25, 36, 49, 64, 81]
sage: [ i^2 for i in range(1,10) if i % 2 == 1]
[1, 9, 25, 49, 81]

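The same comprehension works in plain Python, where Sage's ^ is written **. As a sketch of what the filtered form means, it can be unrolled into an explicit loop that appends to a list:

```python
# Comprehension form: squares of the odd numbers in 1..9
squares = [i**2 for i in range(1, 10) if i % 2 == 1]

# Equivalent explicit loop: the `if` clause filters, the expression is appended
squares_loop = []
for i in range(1, 10):
    if i % 2 == 1:
        squares_loop.append(i**2)

print(squares)                  # [1, 9, 25, 49, 81]
print(squares == squares_loop)  # True
```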
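And a variant of the latter, where a conditional expression replaces the filter, so that even numbers are mapped to 2 instead of being dropped: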

sage: [i^2 if i % 2 == 1 else 2 for i in range(10)]
[2, 1, 2, 9, 2, 25, 2, 49, 2, 81]

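The following small plain-Python comparison shows the difference between the two uses of if. A conditional expression placed before for keeps one output element per input element. An if clause placed after for drops the elements that fail the test.

```python
# Conditional expression: one output element per input element
with_default = [i**2 if i % 2 == 1 else 2 for i in range(10)]

# Filter clause: elements failing the test are dropped entirely
filtered = [i**2 for i in range(10) if i % 2 == 1]

print(with_default)                      # [2, 1, 2, 9, 2, 25, 2, 49, 2, 81]
print(len(with_default), len(filtered))  # 10 5
```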
One can use more than one iterable in a list comprehension:

sage: [ (i,j) for i in range(1,6) for j in range(1,i) ]
[(2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (4, 3), (5, 1), (5, 2), (5, 3), (5, 4)]

Warning

Mind the order of the nested loops in the previous expression: the leftmost for clause is the outer loop, which is why the inner range can depend on i.

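Here is a plain-Python sketch of the nesting order. The comprehension with two for clauses produces exactly the pairs that the corresponding nested for statements would append, and in the same order:

```python
# Two `for` clauses: the leftmost one is the outer loop
pairs = [(i, j) for i in range(1, 6) for j in range(1, i)]

# The same pairs built with explicit nested loops, in the same order
pairs_loop = []
for i in range(1, 6):          # outer loop
    for j in range(1, i):      # inner loop, which may use i
        pairs_loop.append((i, j))

print(pairs == pairs_loop)  # True
print(pairs[:3])            # [(2, 1), (3, 1), (3, 2)]
```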
If instead one wants to build a list of lists, one can nest list comprehensions, as in:

sage: [ [ binomial(n, i) for i in range(n+1) ] for n in range(10) ]
[[1],
[1, 1],
[1, 2, 1],
[1, 3, 3, 1],
[1, 4, 6, 4, 1],
[1, 5, 10, 10, 5, 1],
[1, 6, 15, 20, 15, 6, 1],
[1, 7, 21, 35, 35, 21, 7, 1],
[1, 8, 28, 56, 70, 56, 28, 8, 1],
[1, 9, 36, 84, 126, 126, 84, 36, 9, 1]]

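In plain Python the same triangle can be built with math.comb, which is available from Python 3.8 and is used here as a stand-in for Sage's binomial(). The sketch also contrasts the nested comprehension with a single comprehension that has two for clauses, which flattens the rows into one list:

```python
from math import comb  # Python 3.8+; stand-in for Sage's binomial()

# Inner comprehension builds one row; the outer one collects the rows
pascal = [[comb(n, i) for i in range(n + 1)] for n in range(10)]

# Two `for` clauses in ONE comprehension give a flat list instead
flat = [comb(n, i) for n in range(10) for i in range(n + 1)]

print(pascal[4])               # [1, 4, 6, 4, 1]
print(len(pascal), len(flat))  # 10 55
```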
Iterators

Definition

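To build a comprehension, Python actually uses an iterator. This is an object that runs through a collection of objects, returning one of them at each call to the next() function. Writing a comprehension between parentheses instead of brackets gives a generator expression, which builds such an iterator: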

sage: it = (binomial(8, i) for i in range(9))
sage: next(it)
1
sage: next(it)
8
sage: next(it)
28
sage: next(it)
56

You can get the list of the results that are not yet consumed:

sage: list(it)
[70, 56, 28, 8, 1]

Asking for more elements triggers a StopIteration exception:

sage: next(it)
Traceback (most recent call last):
...
StopIteration

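The for loop, and every comprehension, relies on exactly this protocol. The following plain-Python sketch uses math.comb as a stand-in for binomial. It unrolls a loop over a fresh generator expression into explicit next() calls and stops when StopIteration is raised:

```python
from math import comb

it = (comb(8, i) for i in range(9))
# A generator expression is its own iterator: iter() returns the same object
assert iter(it) is it

# Roughly what `for x in it: consumed.append(x)` does behind the scenes
consumed = []
while True:
    try:
        x = next(it)      # calls it.__next__()
    except StopIteration:
        break             # the iterator is exhausted
    consumed.append(x)

print(consumed)  # [1, 8, 28, 56, 70, 56, 28, 8, 1]
print(list(it))  # [] (nothing left to consume)
```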
An iterator can be used as an argument to a function. The following two idioms give the same result. However, the second one is much more memory efficient for large examples, since it never builds the intermediate list in memory:

sage: sum([binomial(8, i) for i in range(9)])
256
sage: sum(binomial(8, i) for i in range(9))
256

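The memory difference can be observed in plain Python with sys.getsizeof, which reports the size of the container object itself rather than the values it refers to. Exact byte counts depend on the Python version and platform, so this sketch only compares orders of magnitude:

```python
import sys

n = 10**6
as_list = [i % 7 for i in range(n)]   # all n references stored at once
as_gen = (i % 7 for i in range(n))    # values produced on demand

# The list object grows with n; the generator object does not
print(sys.getsizeof(as_list))  # millions of bytes
print(sys.getsizeof(as_gen))   # a few hundred bytes at most, whatever n is

# Both give the same sum (this consumes the generator)
print(sum(as_gen) == sum(as_list))  # True
```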
Typical usage of iterators

Iterators are very handy with the functions all(), any(), and exists():

sage: all([True, True, True, True])
True
sage: all([True, False, True, True])
False
sage: any([False, False, False, False])
False
sage: any([False, False, True, False])
True

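all() and any() are Python built-ins, whereas exists() is provided by Sage. Here is a small plain-Python check of the built-ins' behaviour, including the edge case of an empty iterable:

```python
# On an empty iterable, all() is True (no element is false) and any() is False
print(all([]), any([]))                    # True False

# Both accept any iterable, including generator expressions
print(all(x > 0 for x in [3, 1, 4]))       # True
print(any(x % 2 == 0 for x in [3, 1, 5]))  # False
```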
Let’s check that all prime numbers larger than 2 (here, those below 100) are odd:

sage: all( is_odd(p) for p in range(1,100) if is_prime(p) and p>2 )
True

It is well known that if 2^p-1 is prime then p is prime:

sage: def mersenne(p): return 2^p -1
sage: [ is_prime(p) for p in range(20) if is_prime(mersenne(p)) ]
[True, True, True, True, True, True, True]

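Here is a plain-Python version of the check above. It assumes that simple trial division is good enough for these small numbers. The is_prime below is a hypothetical stand-in for Sage's is_prime, not its implementation. The sketch lists the exponents themselves rather than the booleans:

```python
def is_prime(n):
    """Trial division; a hypothetical stand-in for Sage's is_prime()."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True

def mersenne(p):
    return 2**p - 1

# Exponents p < 20 for which 2^p - 1 is prime
exps = [p for p in range(20) if is_prime(mersenne(p))]
print(exps)                            # [2, 3, 5, 7, 13, 17, 19]
print(all(is_prime(p) for p in exps))  # True
```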
The converse is not true:

sage: all( is_prime(mersenne(p)) for p in range(1000) if is_prime(p) )
False

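Using a list would be much slower here. all() stops at the first false value, so with a generator expression only the first few Mersenne numbers are tested. The list comprehension, in contrast, computes is_prime(mersenne(p)) for every prime p below 1000 before all() even starts: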

sage: %time all( is_prime(mersenne(p)) for p in range(1000) if is_prime(p) )    # not tested
CPU times: user 0.00 s, sys: 0.00 s, total: 0.00 s
Wall time: 0.00 s
False
sage: %time all( [ is_prime(mersenne(p)) for p in range(1000) if is_prime(p)] ) # not tested
CPU times: user 0.72 s, sys: 0.00 s, total: 0.73 s
Wall time: 0.73 s
False

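The difference comes from short-circuiting. This plain-Python sketch records which elements a test function actually sees in each variant. The function is_small is hypothetical, written for this illustration.

```python
checked = []

def is_small(x):
    checked.append(x)  # record every element actually tested
    return x < 3

# Generator: all() stops at the first False (x = 3), so 4..9 are never tested
print(all(is_small(x) for x in range(10)))  # False
print(checked)                              # [0, 1, 2, 3]

checked.clear()
# List: every element is tested while the list is built, before all() runs
print(all([is_small(x) for x in range(10)]))  # False
print(len(checked))                           # 10
```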
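You can get a counterexample using exists(). It takes two arguments, an iterable and a predicate. It returns (True, x) for the first element x satisfying the predicate, or (False, None) if there is none. Here the predicate is the negation of the property we expected to hold: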

sage: exists( (p for p in range(1000) if is_prime(p)), lambda p: not is_prime(mersenne(p)) )
(True, 11)

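The behaviour of exists() can be sketched in a few lines of plain Python. The function exists_sketch below is hypothetical. It was written to match the behaviour described above and is not Sage's source code.

```python
def exists_sketch(iterable, predicate):
    """Hypothetical analogue of Sage's exists(): return (True, x) for the
    first x satisfying predicate, or (False, None) if there is none."""
    for x in iterable:
        if predicate(x):
            return (True, x)   # stop at the first witness
    return (False, None)

print(exists_sketch(range(10), lambda n: n * n > 20))  # (True, 5)
print(exists_sketch(range(3), lambda n: n > 5))        # (False, None)
```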
An alternative way to achieve this is:

sage: counter_examples = (p for p in range(1000) if is_prime(p) and not is_prime(mersenne(p)))
sage: next(counter_examples)
11

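In plain Python, next() also accepts an optional second argument. This default is returned instead of raising StopIteration when the iterator is exhausted, which makes the idiom above safe even when there may be no counterexample at all:

```python
# No even number here: next() returns the default instead of raising
evens = (n for n in [1, 3, 5] if n % 2 == 0)
print(next(evens, None))  # None

# First n whose square exceeds 50, or None if there is none
first_square = next((n for n in range(1, 100) if n * n > 50), None)
print(first_square)       # 8
```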
itertools

As its name suggests, itertools is a module that defines several handy tools for manipulating iterators. As a warm-up, consider pairing each element of a list with its index:

sage: l = [3, 234, 12, 53, 23]
sage: [(i, l[i]) for i in range(len(l))]
[(0, 3), (1, 234), (2, 12), (3, 53), (4, 23)]

The same results can be obtained using enumerate():

sage: list(enumerate(l))
[(0, 3), (1, 234), (2, 12), (3, 53), (4, 23)]

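This plain-Python sketch shows a few equivalent formulations, plus enumerate()'s optional start argument:

```python
l = [3, 234, 12, 53, 23]

# enumerate() pairs each element with its index, lazily
assert list(enumerate(l)) == [(i, l[i]) for i in range(len(l))]
# It is equivalent to zipping the list with a range of indices
assert list(enumerate(l)) == list(zip(range(len(l)), l))

# The optional start argument shifts the numbering
print(list(enumerate(l, start=1))[:2])  # [(1, 3), (2, 234)]
```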
Here is the analogue of list slicing for iterators, using itertools.islice():

sage: list(Permutations(3))
[[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
sage: list(Permutations(3))[1:4]
[[1, 3, 2], [2, 1, 3], [2, 3, 1]]

sage: import itertools
sage: list(itertools.islice(Permutations(3), 1r, 4r))
[[1, 3, 2], [2, 1, 3], [2, 3, 1]]

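Without Sage, the same slice can be taken from itertools.permutations. That function yields tuples rather than Sage permutations, in lexicographic order when the input is sorted. The sketch also shows that, unlike list slicing, islice() consumes the underlying iterator:

```python
import itertools

# itertools.permutations yields tuples, in lexicographic order for sorted input
perms = itertools.permutations([1, 2, 3])

# islice(it, start, stop) is the iterator analogue of seq[start:stop]
middle = list(itertools.islice(perms, 1, 4))
print(middle)  # [(1, 3, 2), (2, 1, 3), (2, 3, 1)]

# Unlike list slicing, islice consumed the first four items of perms
rest = list(perms)
print(rest)    # [(3, 1, 2), (3, 2, 1)]
```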
Note that all calls to islice must have arguments of type int and not Sage integers.

The behaviour of the functions map() and filter() changed between Python 2 and Python 3: in Python 3, they return an iterator. If you want a list, as in Python 2, you need to wrap the result explicitly in list():

sage: list(map(lambda z: z.cycle_type(), Permutations(3)))
[[1, 1, 1], [2, 1], [2, 1], [3], [3], [2, 1]]

sage: list(filter(lambda z: z.has_pattern([1,2]), Permutations(3)))
[[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2]]

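Here is a plain-Python illustration, using a hypothetical list of words in place of permutations. Since map() returns an iterator, it can only be traversed once. Each call can also be written as a list comprehension:

```python
words = ["comprehension", "iterator", "iterable", "yield"]

lengths = map(len, words)  # a lazy iterator, not a list
first = list(lengths)
second = list(lengths)
print(first)               # [13, 8, 8, 5]
print(second)              # [] (the iterator is now exhausted)

long_words = list(filter(lambda w: len(w) > 6, words))
print(long_words)          # ['comprehension', 'iterator', 'iterable']

# Comprehension equivalents build the lists directly
assert [len(w) for w in words] == first
assert [w for w in words if len(w) > 6] == long_words
```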
Defining new iterators

One can very easily write new iterators using the keyword yield: calling a function whose body contains yield returns an iterator, called a generator. The following function does nothing interesting beyond demonstrating the use of yield:

sage: def f(n):
....:   for i in range(n):
....:       yield i
sage: [ u for u in f(5) ]
[0, 1, 2, 3, 4]

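The body of a generator runs lazily: it executes, up to the next yield, only when a value is requested. This plain-Python sketch records when the body starts, in a hypothetical log list:

```python
log = []

def f(n):
    log.append("started")  # runs only on the first next() call
    for i in range(n):
        yield i

g = f(3)
print(log)       # [] (calling f() has not run any of its body yet)
first = next(g)
print(first)     # 0
print(log)       # ['started']
rest = list(g)
print(rest)      # [1, 2]
```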
Iterators can be recursive:

sage: def words(alphabet,l):
....:    if l == 0:
....:        yield []
....:    else:
....:        for word in words(alphabet, l-1):
....:            for a in alphabet:
....:                yield word + [a]

sage: [ w for w in words(['a','b','c'], 3) ]
[['a', 'a', 'a'], ['a', 'a', 'b'], ['a', 'a', 'c'], ['a', 'b', 'a'], ['a', 'b', 'b'], ['a', 'b', 'c'], ['a', 'c', 'a'], ['a', 'c', 'b'], ['a', 'c', 'c'], ['b', 'a', 'a'], ['b', 'a', 'b'], ['b', 'a', 'c'], ['b', 'b', 'a'], ['b', 'b', 'b'], ['b', 'b', 'c'], ['b', 'c', 'a'], ['b', 'c', 'b'], ['b', 'c', 'c'], ['c', 'a', 'a'], ['c', 'a', 'b'], ['c', 'a', 'c'], ['c', 'b', 'a'], ['c', 'b', 'b'], ['c', 'b', 'c'], ['c', 'c', 'a'], ['c', 'c', 'b'], ['c', 'c', 'c']]
sage: sum(1 for w in words(['a','b','c'], 3))
27

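For this particular case, the standard library's itertools.product enumerates the same words, as tuples, in the same order, with the last letter varying fastest. Here is a plain-Python check, repeating the definition above so that the block is self-contained:

```python
from itertools import product

def words(alphabet, l):
    if l == 0:
        yield []
    else:
        for word in words(alphabet, l - 1):
            for a in alphabet:
                yield word + [a]

alphabet = ['a', 'b', 'c']
# product(..., repeat=3) yields the same words, as tuples, in the same order
same = [list(t) for t in product(alphabet, repeat=3)] == list(words(alphabet, 3))
print(same)                                  # True
print(sum(1 for _ in words(alphabet, 3)))    # 27
```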
Here is another recursive iterator:

sage: def dyck_words(l):
....:     if l==0:
....:         yield ''
....:     else:
....:         for k in range(l):
....:             for w1 in dyck_words(k):
....:                 for w2 in dyck_words(l-k-1):
....:                     yield '('+w1+')'+w2

sage: list(dyck_words(4))
['()()()()',
'()()(())',
'()(())()',
'()(()())',
'()((()))',
'(())()()',
'(())(())',
'(()())()',
'((()))()',
'(()()())',
'(()(()))',
'((())())',
'((()()))',
'(((())))']

sage: sum(1 for w in dyck_words(5))
42

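The counts 14 and 42 are Catalan numbers, C_n = binomial(2n, n)/(n+1). The following plain-Python sketch checks, for small n, that dyck_words(n) yields that many distinct, well-parenthesized words. The helper is_balanced is hypothetical, written only for this check.

```python
from math import comb

def dyck_words(l):
    if l == 0:
        yield ''
    else:
        for k in range(l):
            for w1 in dyck_words(k):
                for w2 in dyck_words(l - k - 1):
                    yield '(' + w1 + ')' + w2

def is_balanced(w):
    """True if the parentheses in w are properly matched."""
    depth = 0
    for c in w:
        depth += 1 if c == '(' else -1
        if depth < 0:
            return False
    return depth == 0

for n in range(8):
    ws = list(dyck_words(n))
    assert all(is_balanced(w) for w in ws)
    assert len(set(ws)) == len(ws)               # no duplicates
    assert len(ws) == comb(2 * n, n) // (n + 1)  # n-th Catalan number
print("ok")
```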
Standard Iterables

Finally, many standard Python and Sage objects are iterable; that is, one may iterate through their elements:

sage: sum( x^len(s) for s in Subsets(8) )
x^8 + 8*x^7 + 28*x^6 + 56*x^5 + 70*x^4 + 56*x^3 + 28*x^2 + 8*x + 1

sage: sum( x^p.length() for p in Permutations(3) )
x^3 + 2*x^2 + 2*x + 1

sage: factor(sum( x^p.length() for p in Permutations(3) ))
(x^2 + x + 1)*(x + 1)

sage: P = Permutations(5)
sage: all( p in P for p in P )
True

sage: for p in GL(2, 2): print(p); print("")
[1 0]
[0 1]

[0 1]
[1 0]

[0 1]
[1 1]

[1 1]
[0 1]

[1 1]
[1 0]

[1 0]
[1 1]


sage: for p in Partitions(3): print(p)
[3]
[2, 1]
[1, 1, 1]

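In Python, iterating over an object only requires an __iter__ method that returns an iterator. A plain-Python sketch with a hypothetical Squares class shows the practical difference between an iterable and an iterator. The iterable can be traversed repeatedly, while an iterator such as a generator expression is used up after one pass:

```python
class Squares:
    """Hypothetical iterable: the squares 0, 1, 4, ..., (n-1)^2."""
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # Each call returns a fresh generator, so traversal can be repeated
        for i in range(self.n):
            yield i * i

s = Squares(4)
print(list(s), list(s))  # [0, 1, 4, 9] [0, 1, 4, 9]
print(sum(s), 9 in s)    # 14 True (`in` falls back to iteration)

g = (i * i for i in range(4))  # an iterator: single pass only
print(list(g), list(g))  # [0, 1, 4, 9] []
```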
Beware of infinite loops:

sage: for p in Partitions(): print(p)          # not tested
sage: for p in Primes(): print(p)              # not tested

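In plain Python, itertools.count() is a standard infinite iterator. Nothing is computed until elements are requested. Infinite iterators are therefore safe as long as something bounds the consumption, for instance itertools.islice() or itertools.takewhile():

```python
from itertools import count, islice, takewhile

# count(1) yields 1, 2, 3, ... forever; `for n in count(1)` alone never ends
first_five = list(islice(count(1), 5))
print(first_five)  # [1, 2, 3, 4, 5]

# takewhile() stops at the first element failing the test
small_squares = list(takewhile(lambda s: s < 30, (n * n for n in count(1))))
print(small_squares)  # [1, 4, 9, 16, 25]
```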
Infinite loops can nevertheless be very useful:

sage: exists( Primes(), lambda p: not is_prime(mersenne(p)) )
(True, 11)


sage: counter_examples = (p for p in Primes() if not is_prime(mersenne(p)))
sage: next(counter_examples)
11
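The same counterexample search can be sketched in plain Python. It uses a hypothetical infinite prime generator as a stand-in for Sage's Primes(), and a trial-division is_prime that is also hypothetical. next() stops as soon as the first match is found, so the infinite source is never exhausted:

```python
from itertools import count
from math import isqrt

def primes():
    """Hypothetical infinite prime generator (stand-in for Sage's Primes())."""
    found = []
    for n in count(2):
        # n is prime if no smaller prime p with p*p <= n divides it
        if all(n % p for p in found if p * p <= n):
            found.append(n)
            yield n

def is_prime(n):
    """Trial division; a hypothetical stand-in for Sage's is_prime()."""
    return n >= 2 and all(n % d for d in range(2, isqrt(n) + 1))

counter_examples = (p for p in primes() if not is_prime(2**p - 1))
first = next(counter_examples)
print(first)  # 11, since 2^11 - 1 = 2047 = 23 * 89
```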