17 Jul 2009

Flatten for Python

After making frequent use of lists:flatten/1 in Erlang, I found myself wanting to use this function in Python too. Somewhat surprisingly, Python doesn't have a builtin for flattening lists.

If you're not familiar with flatten, it's most easily explained through some quick examples:

>>> list(flattened([1,2,3,4]))
[1, 2, 3, 4]
>>> list(flattened([[1,2],[3,4]]))
[1, 2, 3, 4]
>>> list(flattened([[[[1],2],3],4]))
[1, 2, 3, 4]

Hopefully you get the idea: it takes nested lists and places all the items they contain, in order, into a single list. My initial hack to get this working was the following one-liner:

from functools import reduce  # needed on Python 3, where reduce is no longer a builtin

def flattened(l):
    return reduce(lambda x, y: x + flattened(y) if type(y) == list else x + [y], l, [])

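This was fine in the scenario I'd written it for, but I knew full well it was going to explode with recursion errors if it ever encountered a deeply nested list, since it recurses once for every level of nesting. Unlike Erlang, Python doesn't optimise tail calls, and it falls over once the call stack reaches the recursion limit (1000 frames by default in CPython), raising a RuntimeError ("maximum recursion depth exceeded"; since Python 3.5 this is a RecursionError, which subclasses RuntimeError).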
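To see the problem concretely, here's a small sketch, assuming Python 3. It repeats the recursive one-liner (so the block is self-contained) and feeds it a list nested more deeply than `sys.getrecursionlimit()`. The helpers `nested` and `hits_recursion_limit` are hypothetical, written just for this demo.

```python
import sys
from functools import reduce


def flattened(l):
    # Recursive one-liner: one extra flattened() call per level of nesting.
    return reduce(lambda x, y: x + flattened(y) if type(y) == list else x + [y], l, [])


def nested(depth):
    # Build [[[...[1]...]]] with `depth` levels, using a loop (no recursion).
    l = [1]
    for _ in range(depth - 1):
        l = [l]
    return l


def hits_recursion_limit(depth):
    try:
        flattened(nested(depth))
        return False
    except RuntimeError:  # RecursionError (Python 3.5+) subclasses RuntimeError
        return True


print(hits_recursion_limit(100))                            # False
print(hits_recursion_limit(sys.getrecursionlimit() + 100))  # True
```

Catching `RuntimeError` rather than `RecursionError` keeps the check working on older interpreters too, since the newer exception is a subclass.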

Recently, I came across this snippet of code again and decided to solve the problem properly. After some Googling around I found a post by Danny Yoo on a Continuation Passing Style (CPS) version of flatten. This got the geek in me all interested, and I decided to see if I could write something in the form of a Python generator instead, while still avoiding the maximum recursion depth issue. Here is the resulting code (seriously nasty code coming up!):

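def flattened(l):
    # Trampoline: _flatten returns either a final item or a pair
    # [thunk, value]. Calling the thunks in a loop instead of recursing
    # means most steps return before the next one begins.
    result = _flatten(l, lambda x: x)
    while type(result) == list and len(result) and callable(result[0]):
        if result[1] != []:  # [] means "no item produced by this step"
            yield result[1]
        result = result[0]([])
    if type(result) != list:  # don't yield the [] marker left by empty lists
        yield result

def _flatten(l, fn, val=[]):
    # Continuation-passing step: fn is what to do with the next item.
    # The [] default is never mutated; it doubles as the "no item" marker.
    if type(l) != list:
        return fn(l)  # a plain item: hand it to the continuation
    if len(l) == 0:
        return fn(val)  # nothing left here: pass the carried item on
    # Return a thunk that flattens the head, then continues with the tail,
    # together with the item carried so far (yielded by the loop above).
    return [lambda x: _flatten(l[0], lambda y: _flatten(l[1:], fn, y), x), val]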
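The trick that lets this avoid deep recursion in most cases is a trampoline: instead of calling the next step directly, `_flatten` returns it, and the `while` loop in `flattened` keeps calling whatever comes back. Stripped of the flattening, the pattern looks roughly like this; `trampoline` and `countdown` are hypothetical names for illustration, not part of the code above.

```python
def trampoline(step):
    # Keep calling returned thunks until a non-callable (final) value appears.
    # Each call returns before the next starts, so the stack stays shallow.
    while callable(step):
        step = step()
    return step


def countdown(n, acc=0):
    # Sum n + (n-1) + ... + 1, but instead of recursing directly,
    # return a thunk describing the next call.
    if n == 0:
        return acc
    return lambda: countdown(n - 1, acc + n)


print(trampoline(countdown(100000)))  # 5000050000
```

This bare version treats any callable as "another step", so it can't return a function as its final result. In `_flatten`, steps are always `[thunk, value]` lists and items are never lists, which is why flattening a list of functions still works there.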

Apart from making my brain hurt, what really amazed me about this approach was the speed increase. On a list nested to a depth of 100, it turned out to be around 2 times faster than Danny Yoo's CPS version. The speedup isn't constant, though: at a depth of 20,000 it was completing around 12 times faster (and yes, that's generating the whole list, not the time taken to hit the first yield). I didn't expect such a difference, but I'm now interested in seeing where else generators and this kind of pattern can be used!

Note: I named the function 'flattened' by analogy with the Python builtin 'reversed', which returns an iterator rather than modifying the list in place. This function also only flattens lists, not tuples (intentionally); if you wanted it to flatten tuples, it should only require some slight tweaking (which I'll leave to you).
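For comparison, here's a sketch using a different technique: an explicit stack of iterators rather than continuations. As far as I can tell from tracing it, the trampolined version still makes nested calls when several lists close at the same point (it calls `fn(val)` directly for an empty tail), so a deep right-nested list such as `[1, [2, [3, ...]]]` can still hit the recursion limit; with an explicit stack, depth is limited only by memory. The name `flattened_iter` and its `seq_types` parameter are hypothetical; `seq_types` shows one way to do the tuple tweak mentioned above. Note that `isinstance` also matches subclasses of list, unlike the `type(...) == list` checks in the original.

```python
def flattened_iter(l, seq_types=(list,)):
    # Hypothetical alternative: keep a stack of iterators instead of
    # recursing or building continuations.
    stack = [iter(l)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, seq_types):
                stack.append(iter(item))  # descend into the nested sequence
                break
            yield item
        else:
            stack.pop()  # current (sub)list exhausted; resume its parent


print(list(flattened_iter([[[[1], 2], 3], 4])))  # [1, 2, 3, 4]
print(list(flattened_iter([1, (2, [3, (4,)])], seq_types=(list, tuple))))  # [1, 2, 3, 4]

deep = []
for i in reversed(range(100000)):
    deep = [i, deep]  # builds [0, [1, [2, ..., [99999, []]]]]
print(list(flattened_iter(deep)) == list(range(100000)))  # True
```

Strings are deliberately left out of the default: a one-character string iterates to itself, so adding `str` to `seq_types` would never terminate.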