2020-12-08
< view all posts首先说一个小知识点:在Scala中,递归函数的返回值类型是必须指定的(而其它函数则不必,当然从代码风格的角度,全都写明会更加清楚)。因为编译器会尝试从函数的右侧表达式找这个函数的返回值,但是对递归函数,它会找到这个函数本身,导致返回值无法确定。
尾递归:如果一个函数的递归调用仅仅是调用其本身,那么就是一个尾递归。Scala的编译器会对尾递归进行优化,使其在执行时只占用恒量的内存,因此可以达到和循环一致的执行效率。
类似的,尾调用的定义是,如果一个函数最后的动作仅仅是对(本身或另一个)函数的调用,那么调用栈就不会动态增大。
这里值得一提的是,并非所有语言都对尾递归、尾调用有优化,例如Python(CPython)就不提供原生的尾递归优化。关于具体的理由可以参见 Final Words on Tail Calls
而在Scala这样的语言中,尾递归的执行效率会比非尾递归更高。一个非尾递归的例子,阶乘:
def factorial(n: Int): Int = if (n == 0) 1 else n * factorial(n - 1)
改写成尾递归:
def factorial(n: Int): Int ={ def doFactorial(n: Int, acc: Int): Int = if (n == 0) acc else doFactorial(n - 1, n * acc) doFactorial(n, 1) }
思路是用参数来传递计算的中间结果,而不是将表达式保存在调用栈中。这里我们把递归函数外面又包装了一层函数,是为了对参数做内部初始化,不把它暴露在外面。在递归中需要初始化时,这是一种比较好用的写法。
还可以用非递归的循环写法来写,比较一下:
def factorial(n: Int): Int = { var acc = 1 var i = 1 while (i <= n) { acc = acc * i i += 1 } acc }
可以感受到,函数式(递归)的写法,相当于定义了如下的数学函数:
f(n, acc) = f(n * acc, n - 1) (when n > 0) = acc (when n = 0)
而程序的执行过程,实际上是对这个函数层层代入进行求值的过程。也就是所谓的代换模型(substitution model)
而上面循环的例子是典型的指令式写法,用变量来存储数据,用流程控制来规定程序的执行逻辑:
f(n) = ( (((1 * 2) * 3) * 4) * ...*(n-1)) * n
从直观的感受来讲,这种写法的抽象程度没有递归写法的高。
不过是不是递归写法就一定要求更高的抽象程度呢,其实也不一定。上面提到,写尾递归的时候,我们用函数的参数代替了变量对中间值的存储。其实这个思路可以用来把循环“暴力”改写成递归:
def factorial(n: Int): Int ={ def doFactorial(i: Int, acc: Int): Int = if (i > n) acc else doFactorial(i + 1, acc * i) doFactorial(1, 1) }
实际上做的操作就是把原来定义的一个循环变量i,变成了一个随着递归不断传递下去的函数参数。和前一个尾递归的写法相比,形式上只是把从后往前算阶乘变成了从前往后算,不过在思路上却有比较大的区别。另外,这也是对递归和循环等价性的一种佐证。
*写递归的两种思路*
可以总结出写递归是有两种思路的。一种是首先考虑最外层需要做什么,之后使用相同的逻辑,把剩下的部分再当作输入。例如统计一个n长度的字符串中某个字符的个数,那么很明显,最外层只需要判断第一个字符,之后还剩下n-1长度,就再把这剩下的部分当作已经写好的最外层逻辑的输入。
再讲清楚一点,这里所说的“最外层”,就是指递归在得到结果前的最后一次执行,即“临门一脚”。而这之前的执行结果,都可以当作已经处理完成,用调用同一个函数的方式写出。在统计字符个数的例子中,最外层就是前n-1个字符都处理完毕之后,对剩下的最后一个字符的处理。把对前n-1个字符的处理结果,用 count(x.init, target) 表示,那么在“最后一步”,我们要做的就是判断 x.last,并且把结果加上去:
def count(s: String, target: Char): Int = { if (s.length == 0) 0 // 终止条件 else if (s.last.equals(target)) // 判断字符串的最后一个字符,即s.last 1 + count(s.init, target) // 加到对剩余字符,即s.init的判断结果当中去 else count(s.init, target) }
这种思路写出来的递归,很有可能是尾递归,是对同一段逻辑的重复利用。这种递归在工程开发中用处是比较大的,比如对字符串、文件之类的处理上面,用到比较多。它的好处是,比起在循环里使用循环变量(循环变量的边界问题常常比较麻烦),而递归里只考虑单次的逻辑+停止条件,思路上常常会更加清晰,写起来也更顺手一些。当然,这类递归要改成循环也是比较方便的。
那么另一种思路,就是函数式的递归。用指令式的思路去看这种递归,第一眼常常会觉得有点神奇。这篇文章里有一个很形象的说法“写递归的要点:明白一个函数的作用并相信它能完成这个任务,千万不要跳进这个函数里面企图探究更多细节”。这是一种很直觉的理解,也是很正确的。跳进调用栈里去尝试理解这类递归是很难的。
不过比起单纯的“相信它能完成这个任务”,我们可以抓住这类递归的原理,用类似于推导数学公式的方法来写这类递归。其实过程有点像写数列的通项公式,需要我们用代换模型去把需要求解的项目不断向已经定义的逻辑作替换。这样说比较抽象,看一个具体的例子:
问题是给一个整数数组和一个整数,求数组元素相加得到给定整数的组合一共有多少种。用递归的方式来写。
我们把这个递归函数定义为 solve() ,例如对它的一个调用是 solve([1,2,3], 17) 。接下来,要用已经定义的solve()去对这个调用作代换,只要不断代换下去直到停止条件,那么这个调用就能求解了。
如何代换呢,用简单的思路,把所有可能性遍历一遍。首先不用1,只用2,3去组合,那么可以写成 solve([2,3],17) ;下一种情况,用一次1,之后用2,3去组合16,即 solve([2,3],16);再下一种情况用两次1……那么就可以写出:
solve([1,2,3], 17) = solve([2,3],17) + solve([2,3],16) + ... + solve([2,3],0)
这里我们已经考虑了一个停止条件,也就是 solve() 的第二个参数小于0的时候,停止递归,并且返回0(减到负数,说明组合不能成立);当它等于0的时候,停止递归,并且返回1(刚好减到0,说明组合成立)。
到这一步我们得到了一个代换后的式子,但是实现这个式子需要一个循环的逻辑,因为项数是不确定的。当然可以直接用循环写,也可以更进一步,只用递归——观察从第二项起的后面每一项,以及我们已经写出的代换逻辑,很容易发现从第二项加到最后一项,其实就等于solve([1,2,3],16):
solve([2,3],16) + ... + solve([2,3],0) = solve([1,2,3],16)
因此原式就可以进一步写成:
solve([1,2,3], 17) = solve([2,3],17) + solve([1,2,3],16)
OK,到这里我们就成功找到这个问题的“递推公式”了。那么还剩下最后一个问题,就是另一个和数组相关的终止条件。其实从直觉上很容易判断,数组为空的时候 solve() 返回0就可以,也可以选择一个式子,比如 solve([3],17),来看一下具体的情况:
solve([3],17) = solve([],17) + solve([3],14) = solve([],17) + solve([],14) + solve([3],11) = ...
这个式子最终会因为代换到... solve([],2) + solve([3],-1) 而停止,它的最终结果只受最后一项影响,所以前面所有参数为空数组[]的调用,全部返回0即可。
归纳一下上面的分析,把它写成Scala代码:
def solve(money: Int, coins: List[Int]): Int = { if (money == 0) 1 else if (coins.isEmpty || money < 0) 0 else solve(money, coins.tail) + solve(money - coins.head, coins) }
只需要三行代码就能实现。而在这后面作为支撑的,就是函数式编程的代换模型,和数学思维。