Tail recursion in Python and how to optimize it

lee Alexander <[email protected]>
reply-to        [email protected]
to      [email protected]
date    Thu, Sep 16, 2010 at 01:57
subject [CPyUG] A very cool idea -> tail recursion optimization in Python

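Reference

The ActiveState proposal

I stumbled on this on an overseas site and found it very cool, so I am sharing it here. Generally speaking, Python, like Java and C#, does not optimize tail recursion automatically, and recursive calls are widely criticized for being bounded by the depth of the call stack. This author, however, solved the problem with a startling trick and implemented it in pure Python: a tail-recursive function wrapped this way is no longer constrained by the call-stack depth. Very cool.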

def Fib(n,b1=1,b2=1,c=3):
    if n<3:
        return 1
    else:
        if n==c:
            return b1+b2
        else:
            return Fib(n,b1=b2,b2=b1+b2,c=c+1)

Let's test this program. Calling Fib(1001) gives:

>>> def Fib(n,b1=1,b2=1,c=3):
...     if n<3:
...         return 1
...     else:
...         if n==c:
...             return b1+b2
...         else:
...             return Fib(n,b1=b2,b2=b1+b2,c=c+1)
...
>>> Fib(1001)
70330367711422815821835254877183549770181269836358732742604905087154537118196933
57974224949456261173348775044924176599108818636326545022364710601205337412127386
7339111198139373125598767690091902245245323403501L

Call Fib(1002), though, and it ends in tears:

  .....
  File "<stdin>", line 8, in Fib
  File "<stdin>", line 8, in Fib
  File "<stdin>", line 8, in Fib
  File "<stdin>", line 8, in Fib
  File "<stdin>", line 8, in Fib
  File "<stdin>", line 8, in Fib
RuntimeError: maximum recursion depth exceeded
>>>
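
The ceiling hit here is CPython's recursion limit, which defaults to 1000 and can be read with sys.getrecursionlimit(). As a minimal sketch, assuming Python 3 (where the error is RecursionError, a subclass of RuntimeError) rather than the Python 2 session shown in the thread, the following reproduces the failure without depending on the exact default. The helper name hits_recursion_limit is made up for illustration.

```python
import sys

def Fib(n, b1=1, b2=1, c=3):
    # Same tail-recursive Fibonacci as in the thread: b1 and b2 carry the
    # two previous values, c counts up towards n.
    if n < 3:
        return 1
    if n == c:
        return b1 + b2
    return Fib(n, b1=b2, b2=b1 + b2, c=c + 1)

def hits_recursion_limit(n):
    # Hypothetical helper: True if Fib(n) needs more frames than allowed.
    try:
        Fib(n)
    except RecursionError:
        return True
    return False

limit = sys.getrecursionlimit()
print(limit)                              # the interpreter's current limit
print(hits_recursion_limit(limit + 100))  # deeper than the limit allows
```

Fib(n) needs roughly n nested calls, so any n comfortably above the limit fails in the same way as Fib(1002) did in the session.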

Now let's apply the tail recursion optimization.

We add a decorator to the Fib function above:

@tail_call_optimized
def Fib(n,b1=1,b2=1,c=3):
    if n<3:
        return 1
    else:
        if n==c:
            return b1+b2
        else:
            return Fib(n,b1=b2,b2=b1+b2,c=c+1)

@tail_call_optimized

Yes, it is this @tail_call_optimized decorator that lets Python break through the call-stack limit, as if by magic.

Enough suspense. Here is the magic code:

# This program shows off a python decorator
# which implements tail call optimization. It
# does this by throwing an exception if it is
# its own grandparent, and catching such
# exceptions to recall the stack.
# (Python 2 syntax, as posted.)

import sys

class TailRecurseException:
  # Carries the arguments of the next call back up to the outer wrapper.
  def __init__(self, args, kwargs):
    self.args = args
    self.kwargs = kwargs

def tail_call_optimized(g):
  """
  This function decorates a function with tail call
  optimization. It does this by throwing an exception
  if it is its own grandparent, and catching such
  exceptions to fake the tail call optimization.

  This function fails if the decorated
  function recurses in a non-tail context.
  """
  def func(*args, **kwargs):
    f = sys._getframe()
    # Frame chain func -> g -> func means g is calling itself.
    if f.f_back and f.f_back.f_back \
        and f.f_back.f_back.f_code == f.f_code:
      raise TailRecurseException(args, kwargs)
    else:
      # Outermost wrapper: keep calling g with the latest arguments.
      while 1:
        try:
          return g(*args, **kwargs)
        except TailRecurseException, e:
          args = e.args
          kwargs = e.kwargs
  func.__doc__ = g.__doc__
  return func

@tail_call_optimized
def factorial(n, acc=1):
  "calculate a factorial"
  if n == 0:
    return acc
  return factorial(n-1, n*acc)

print factorial(10000)
# prints a big, big number,
# but doesn't hit the recursion limit.

@tail_call_optimized
def fib(i, current = 0, next = 1):
  if i == 0:
    return current
  else:
    return fib(i - 1, next, current + next)

print fib(10000)
# also prints a big number,
# but doesn't hit the recursion limit.
## end of http://code.activestate.com/recipes/474088/

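Discussion

The usage has been shown above. What opened my eyes is that the author stops the call stack from growing by raising an exception and catching it himself: when the wrapper sees that it is being re-entered from its own decorated function, it raises an exception carrying the new arguments, and the outermost wrapper catches it and calls the function again in a loop. It is an astonishing idea. As for efficiency, it takes roughly five times as long as plain recursion. In the end, remarkably, the goal of tail recursion optimization is achieved.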
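
The recipe is written for Python 2. As a rough sketch of the same raise-and-catch mechanism, assuming Python 3, the port below also shows the failure mode the recipe's docstring warns about. The attribute names call_args and call_kwargs, the use of functools.wraps, and the count_down function are choices made for this illustration and are not part of the original recipe.

```python
import functools
import math
import sys

class TailRecurseException(Exception):
    """Carries the arguments of the next tail call up to the outer wrapper."""
    def __init__(self, call_args, call_kwargs):
        super().__init__()
        self.call_args = call_args
        self.call_kwargs = call_kwargs

def tail_call_optimized(g):
    @functools.wraps(g)
    def func(*args, **kwargs):
        f = sys._getframe()
        # Frame chain func -> g -> func: g is calling itself, so unwind
        # to the outer wrapper instead of going one level deeper.
        if f.f_back and f.f_back.f_back and f.f_back.f_back.f_code is f.f_code:
            raise TailRecurseException(args, kwargs)
        while True:
            try:
                return g(*args, **kwargs)
            except TailRecurseException as e:
                args, kwargs = e.call_args, e.call_kwargs
    return func

@tail_call_optimized
def factorial(n, acc=1):
    # Proper tail call: nothing is left to do after the recursive call.
    if n == 0:
        return acc
    return factorial(n - 1, n * acc)

@tail_call_optimized
def count_down(n):
    # NOT a tail call: the "1 +" still has to run after the recursive call,
    # but the exception unwinds past it, so that work is silently dropped.
    if n == 0:
        return 0
    return 1 + count_down(n - 1)

print(factorial(5000) == math.factorial(5000))  # True, well past the default limit
print(count_down(5))                            # 0, not 5
```

The second result is why the decorator should only be applied to functions whose recursive call is the very last thing they do: a wrong answer comes back without any error.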

机械唯物主义

机械唯物主义 : linjunhalida <[email protected]>
reply-to        [email protected]
to      [email protected]
date    Thu, Sep 16, 2010 at 08:13

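If efficiency is the goal, it is better to convert to iteration. Here is an iterative version that stays fairly close to the recursive one.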

def process(b1, b2, c):
    t = b2
    b2 += b1
    b1 = t
    c += 1
    return b1, b2, c

def fib(n):
    b1, b2, c = 1, 1, 3
    while 1:
        if n<3:
            return 1
        elif n==c:
            return b1+b2
        else:
            b1, b2, c = process(b1, b2, c)

Shell Xu's benchmark

Shell Xu <[email protected]>
reply-to        [email protected]
to      [email protected]
date    Thu, Sep 16, 2010 at 10:34
subject Re: [CPyUG] A very cool idea -> tail recursion optimization in Python

# -*- coding: utf-8 -*-
# @date: 2010-09-16
# @author: shell.xu
import os
import sys

class TailRecurseException:
    def __init__(self, args, kwargs):
        self.args = args
        self.kwargs = kwargs

def tail_call_optimized(g):
    def func(*args, **kwargs):
        f = sys._getframe()
        if f.f_back and f.f_back.f_back and\
                f.f_back.f_back.f_code == f.f_code:
            raise TailRecurseException(args, kwargs)
        else:
            while 1:
                try: return g(*args, **kwargs)
                except TailRecurseException, e:
                    args, kwargs = e.args, e.kwargs
    return func

class InnerDt(object):
    def __init__(self, params, kargs):
        self.params, self.kargs = params, kargs

def tail_call(func):
    def inner(*params, **kargs):
        f = sys._getframe()
        if f.f_back and f.f_back.f_back and\
                f.f_back.f_back.f_code == f.f_code:
            return InnerDt(params, kargs)
        else:
            while 1:
                r = func(*params, **kargs)
                if not isinstance(r, InnerDt): return r
                else: params, kargs = r.params, r.kargs
    return inner

@tail_call_optimized
def Fib(n,b1=1,b2=1,c=3):
    if n<3: return 1
    if n==c: return b1+b2
    return Fib(n,b1=b2,b2=b1+b2,c=c+1)

@tail_call
def Fibr(n,b1=1,b2=1,c=3):
    if n<3: return 1
    if n==c: return b1+b2
    return Fibr(n,b1=b2,b2=b1+b2,c=c+1)

def process(b1, b2, c):
    return b2, b1 + b2, c + 1

def fib1(n):
    b1, b2, c = 1, 1, 3
    while 1:
        if n<3: return 1
        elif n==c: return b1+b2
        else: b1, b2, c = process(b1, b2, c)

def fib2(n):
    if n < 3: return 1
    f1, f2 = 1, 1
    for i in xrange(3, n):
        if i & 1: f1 += f2
        else: f2 += f1
    return f1 + f2

if __name__=='__main__':
    from timeit import Timer
    t = Timer("Fib(10000)", "from __main__ import *")
    tr = Timer("Fibr(10000)", "from __main__ import *")
    t1 = Timer("fib1(10000)", "from __main__ import *")
    t2 = Timer("fib2(10000)", "from __main__ import *")
    print t.timeit(10), tr.timeit(10), t1.timeit(10), t2.timeit(10)
    print Fib(10000) == Fibr(10000), fib1(10000) == fib2(10000), Fib(10000) == fib1(10000)

---------------------------------------------------------------------------

c:\Documents and Settings\Administrator\My Documents\note>test.py
<timeit-src>:2: SyntaxWarning: import * only allowed at module level
<timeit-src>:2: SyntaxWarning: import * only allowed at module level
<timeit-src>:2: SyntaxWarning: import * only allowed at module level
<timeit-src>:2: SyntaxWarning: import * only allowed at module level
1.21529047311 4.59827187267 0.325211272378 0.279499038127
True True True

---------------------------------------------------------------------------

Fib(20000)
python -m profile test.py

---------------------------------------------------------------------------

         80000 function calls (60003 primitive calls) in 0.681 CPU seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    19998    0.059    0.000    0.059    0.000 :0(_getframe)
        1   -0.001   -0.001    0.674    0.674 :0(execfile)
        1    0.004    0.004    0.004    0.004 :0(setprofile)
        1    0.003    0.003    0.677    0.677 <string>:1(<module>)
        1    0.000    0.000    0.681    0.681 profile:0(execfile('test.py'))
        0    0.000             0.000          profile:0(profiler)
        1    0.000    0.000    0.000    0.000 test.py:13(tail_call_optimized)
  19998/1    0.337    0.000    0.674    0.674 test.py:14(func)
        1    0.000    0.000    0.000    0.000 test.py:26(InnerDt)
        1    0.000    0.000    0.000    0.000 test.py:30(tail_call)
    19998    0.245    0.000    0.585    0.000 test.py:43(Fib)
        1    0.000    0.000    0.675    0.675 test.py:5(<module>)
        1    0.000    0.000    0.000    0.000 test.py:8(TailRecurseException)
    19997    0.033    0.000    0.033    0.000 test.py:9(__init__)

---------------------------------------------------------------------------

Fibr(20000)
python -m profile test.py

---------------------------------------------------------------------------

         99998 function calls (80001 primitive calls) in 0.677 CPU seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    19998    0.038    0.000    0.038    0.000 :0(_getframe)
        1    0.002    0.002    0.673    0.673 :0(execfile)
    19998    0.065    0.000    0.065    0.000 :0(isinstance)
        1    0.004    0.004    0.004    0.004 :0(setprofile)
        1    0.000    0.000    0.673    0.673 <string>:1(<module>)
        1    0.000    0.000    0.677    0.677 profile:0(execfile('test.py'))
        0    0.000             0.000          profile:0(profiler)
        1    0.000    0.000    0.000    0.000 test.py:13(tail_call_optimized)
        1    0.000    0.000    0.000    0.000 test.py:26(InnerDt)
    19997    0.076    0.000    0.076    0.000 test.py:27(__init__)
        1    0.000    0.000    0.000    0.000 test.py:30(tail_call)
  19998/1    0.247    0.000    0.671    0.671 test.py:31(inner)
    19998    0.244    0.000    0.539    0.000 test.py:49(Fibr)
        1    0.000    0.000    0.671    0.671 test.py:5(<module>)
        1    0.000    0.000    0.000    0.000 test.py:8(TailRecurseException)

Why do the timeit measurements differ so much from the profile results?
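
The thread leaves this question open. Two documented differences between the tools may be worth ruling out first, offered here as assumptions rather than a diagnosis: timeit disables garbage collection while timing by default, and a profiler hooks every call and return, which adds overhead of its own. Below is a minimal Python 3 sketch for measuring the same workload three ways. The function fib_iter and the three helper names are made up for illustration, and cProfile is swapped in for the profile module used in the thread.

```python
import cProfile
import pstats
import timeit

def fib_iter(n):
    # Stand-in workload; substitute the function under test.
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a

def time_plain(func, number=10):
    # timeit's default behaviour: garbage collection is off while timing.
    return timeit.timeit(func, number=number)

def time_with_gc(func, number=10):
    # Re-enable the collector in the setup step, as the timeit docs suggest.
    return timeit.timeit(func, setup="import gc; gc.enable()", number=number)

def time_profiled(func, *args):
    # Total time as seen by the profiler, including its per-call overhead.
    profiler = cProfile.Profile()
    profiler.runcall(func, *args)
    return pstats.Stats(profiler).total_tt

print(time_plain(lambda: fib_iter(10000)))
print(time_with_gc(lambda: fib_iter(10000)))
print(time_profiled(fib_iter, 10000))
```

Comparing the three figures for each variant on the same machine would show how much of the gap comes from the measuring tool and how much from the code being measured.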


Feedback

Created by -- ZoomQuiet [2010-09-16 01:04:42]

MiscItems/2010-09-16 (last edited 2010-09-16 03:30:18 by ZoomQuiet)