Y Combinator in Python

I just finished reading The Little Schemer. What an awesome book! I love the Q&A format of this book and can't wait to read the next book in the series, The Seasoned Schemer.

Chapters 9 and 10 (the last two chapters) were really good. Chapter 10 is about writing a small Scheme interpreter. But it's Chapter 9 that I want to discuss in this blog post, because Chapter 9 introduces the Y combinator.

I've tried to understand it in the past and failed (the Wikipedia page is completely unintelligible), but this book explains it in a brilliant way. In fact, it derives the Y combinator from the ground up, making sure each step is easy to understand.

As an exercise, I tried to write it in Python. This is how it came out:

def Y(le):
    def _anon(cc):
        return le(lambda x: cc(cc)(x))
    return _anon(_anon)

If you don't know the Y combinator already, that probably looks quite cryptic.
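One piece that is easy to miss is the lambda x: cc(cc)(x) wrapper. Python evaluates function arguments eagerly, so handing cc(cc) to le directly would keep calling _anon forever before le ever runs. Here is a minimal sketch of that failure, assuming Python 3 (where the error is RecursionError). Y_eager and identity are made-up names for illustration only.

```python
def Y_eager(le):
    # Hypothetical variant of Y without the lambda wrapper.
    def _anon(cc):
        # cc(cc) is evaluated right here, before le is called,
        # and evaluating it calls _anon again, and again, and again...
        return le(cc(cc))
    return _anon(_anon)

def identity(fn):
    # Any function at all will do; le is never actually reached.
    return fn

try:
    Y_eager(identity)
    outcome = "returned"
except RecursionError:
    outcome = "RecursionError"
```

With the lambda in place, cc(cc) is only evaluated when the inner function is actually called with an argument, which is what lets Y return at all.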

To give a simplified explanation, the Y combinator is a function that takes one function as input and creates a recursive version as an output. Maybe an example will illustrate the point better:

def _1(factorial):
    def _fn(n):
        if n == 0: return 1
        else:
            return n*factorial(n-1)
    return _fn

Take a look at this function. The function name is _1. The function body has something that looks like a recursive factorial implementation, except that it never calls itself (remember, the function name is _1). Instead, it recurses on "factorial", which is a parameter of the function _1.

Basically, the "factorial" parameter is a function, so the meaning of _1 is:

  • If n is zero, return one
  • else, take the factorial parameter, call that function with parameter n-1, and once it returns, multiply the result by n and return that value

Exploring a bit more:

def error(n): raise Exception

f = _1(error) # passing function "error" as the parameter
f(0)   # prints 1
f(1)   # Exception

The function f above is what we get by passing error as a parameter to _1. If n is zero, it returns 1. Otherwise, it goes to the else part and calls the function that we passed as a parameter, in this case error, which raises an exception.

In order for this to be recursive, we don't want it to call error; we want it to call the same function again. What if we passed the function itself as the parameter? Then in the else part, instead of calling error, it would call another copy of itself. Something like this:

f = _1(_1(error))
f(0)   # prints 1
f(1)   # prints 1
f(2)   # Exception

f = _1(_1(_1(_1(error))))
f(0)   # prints 1
f(1)   # prints 1
f(2)   # prints 2
f(3)   # prints 6
f(4)   # Exception
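The nesting above can also be built with a loop instead of typing out the parentheses. This is only a sketch: tower is a hypothetical helper that is not part of the book's derivation, and _1 and error are repeated so the snippet runs on its own.

```python
def _1(factorial):
    def _fn(n):
        if n == 0:
            return 1
        else:
            return n * factorial(n - 1)
    return _fn

def error(n):
    raise Exception

def tower(depth):
    # Hypothetical helper: wraps error in `depth` layers of _1,
    # so tower(4) is the same as _1(_1(_1(_1(error)))).
    f = error
    for _ in range(depth):
        f = _1(f)
    return f

f = tower(4)
f(3)    # 6, the deepest input that four layers can handle
# f(4) would raise Exception, because it runs out of layers
g = tower(11)
g(10)   # 3628800
```

However deep the tower is, there is always some input that falls through to error.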

Hmm, in each case, the recursion finally stops with an exception when it encounters the error function. A truly recursive version would look like this:

f = _1(_1(_1(_1(_1...... forever

Well, that is basically what the Y combinator does. Using some function-passing magic, it converts _1 into the forever-recursive version. How?? Well, that's the magic!! It's a bit complicated to explain here. Get the book and read Chapter 9. But it works! Check this out:

f = Y(_1)
f(0)   # prints 1
f(1)   # prints 1
f(5)   # prints 120
f(10)   # prints 3628800
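For a rough feel of the magic without the full derivation, one can count how often the function passed to Y gets invoked. The sketch below uses counting_1, a hypothetical copy of _1 that records each invocation in a list named calls (both names are made up for this illustration). It suggests that Y never builds an infinite tower up front; it adds one more layer each time the recursion needs one.

```python
def Y(le):
    def _anon(cc):
        return le(lambda x: cc(cc)(x))
    return _anon(_anon)

calls = []  # one entry per invocation of counting_1

def counting_1(factorial):
    # Same body as _1, plus bookkeeping.
    calls.append(1)
    def _fn(n):
        if n == 0:
            return 1
        else:
            return n * factorial(n - 1)
    return _fn

f = Y(counting_1)
layers_after_Y = len(calls)      # 1: Y builds only the first layer
result = f(3)                    # 6
layers_after_call = len(calls)   # 4: one new layer for each of n = 2, 1, 0
```

So the "forever" version is really an on-demand version: a fresh layer of _1 is created only when the previous layer calls its factorial parameter.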

Amazing, isn't it? And the even more amazing thing is that it is not specific to factorial. See this:

def _2(length):
    def _fn(alist):
        if not alist: return 0
        else:
            return 1 + length(alist[1:])
    return _fn

f = Y(_2) # calculate length of a list
f([])   # prints 0
f([1,2,3,4,5])   # prints 5

Woohoo!! On a roll now..

def _3(reverse):
    def _fn(alist):
        if not alist: return []
        else:
            return reverse(alist[1:]) + [alist[0]]
    return _fn

f = Y(_3) # reverse a list
f([])   # prints []
f([1,2,3])   # prints [3,2,1]

You can take any recursive function of one argument, rewrite it in the above style, and the Y combinator will make a recursive version of it. How cool is that?
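As one more illustration of the same recipe, here is a sketch for Fibonacci numbers. The name _4 just continues the _1, _2, _3 pattern and is not from the book; Y is repeated so the snippet runs on its own. Note that this one recurses twice per call, and Y handles that without any changes.

```python
def Y(le):
    def _anon(cc):
        return le(lambda x: cc(cc)(x))
    return _anon(_anon)

def _4(fib):
    # Hypothetical example: fib is the "recursive call" parameter,
    # exactly like factorial, length and reverse earlier.
    def _fn(n):
        if n < 2:
            return n
        else:
            return fib(n - 1) + fib(n - 2)
    return _fn

f = Y(_4)  # calculate Fibonacci numbers
first_ten = [f(n) for n in range(10)]
# first_ten is [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```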