这个学期学了一门函数式语言 Coq ,是法国人发明的,它与 ML 很像,其实 Coq 的发明主要是受了 ML 的影响。Coq 的语法我就不在这叙述了,有兴趣的朋友可以看这本书 Software Foundations ,也就是我们上课用的教材。下面直接看些例子吧。
1. map
它的功能与 C++ 中的std::transform
算法类似,是将f
作用在list
中的每个元素(元素类型是X
)上,返回一个list
(元素类型是Y
),其中f
是一元函数,形参类型是X
,返回值类型是Y
。有一点得注意,这里的map
和list
与 STL 中的容器无关,因为 Coq 的源程序中是这么命名的,所以我就沿用了。
Fixpoint map { X Y : Type } ( f : X -> Y ) ( l : list X )
: ( list Y ) :=
match l with
| [] => []
| h :: t => ( f h ) :: ( map f t )
end .
我突发奇想,用 C++ 实现了上面的功能。程序如下。
template < class X , class UnaryOperation , template < class , class ... > class List >
auto map ( UnaryOperation f , List < X > l )
{
if ( l . empty ()) {
return List < decltype ( f ( X ())) > ();
}
X h = l . front ();
l . pop_front ();
auto ret = map ( f , l );
ret . push_front ( f ( h ));
return ret ;
}
在写这个函数时,我最大的感受是函数式语言的自动类型推导 特别强大,而用 C++ 改写时就得使用泛型技术实现,上面的代码用了 C++14 标准。
2. fold
下面再实现一个与map
功能相反的函数fold
,它的功能是将list
中的元素归约,与 STL 中的std::accumulate
类似。Coq 的代码如下。
Fixpoint fold { X Y : Type } ( f : X -> Y -> Y ) ( l : list X ) ( b : Y ) : Y :=
match l with
| nil => b
| h :: t => f h ( fold f t b )
end .
相应地,用 C++ 改写后程序如下。
template < class Y , class BinaryOperation , class List >
Y fold ( BinaryOperation f , List l , Y b )
{
if ( l . empty ()) {
return b ;
}
auto h = l . front ();
l . pop_front ();
return f ( h , fold ( f , l , b ));
}
3. fold_map
在学 Coq 时,这个函数是个练习题,要求用fold
实现map
。我刚开始看到这个题目时觉得挺神奇,fold
是“变小”,而map
“不变”,怎么能用fold
实现map
呢?其实关键在于fold
中的那个f
,它的类型可以看成是X->Y->Y
,这里的X
和Y
可以不同。比如X
可以是自然数,Y
可以是list
,这就“变大”了。实现如下。
Definition fold_map { X Y : Type } ( f : X -> Y ) ( l : list X ) : list Y :=
fold ( fun x l => ( f x ) :: l ) l [] .
相应的 C++ 代码如下。
/*
* use fold to implement map
*/
template < class X , class UnaryOperation , template < class , class ... > class List >
auto fold_map ( UnaryOperation f , List < X > l )
{
using L = List < decltype ( f ( X ())) > ;
auto bin_fun = [ & f ] ( X x , L lst ) {
lst . push_front ( f ( x ));
return lst ;
};
return fold ( bin_fun , l , L ());
}
在函数式语言中使用 Lambda 表达式是很常见的,并且捕捉列表都不用指明,语法简单,不过它内部的实现工作量大,所以效率就没 C++ 高。
4. 用 STL 实现 fold 和 map
上面说了,map
和fold
的功能分别与std::transform
和std::accumulate
功能类似。所以也可以用它们来实现,程序如下。
template < class Y , class BinaryOperation , class List >
Y fold ( BinaryOperation f , List l , Y b )
{
// reverse parameter list in f
auto binary_op = [ & f ] ( auto a , auto b ) {
return f ( b , a );
};
return accumulate ( crbegin ( l ), crend ( l ), b , binary_op );
}
template < class X , class UnaryOperation , template < class , class ... > class List >
auto map ( UnaryOperation f , List < X > l )
{
List < decltype ( f ( X ())) > dst_list ( l . size ());
transform ( cbegin ( l ), cend ( l ), begin ( dst_list ), f );
return dst_list ;
}
其实fold
与std::accumulate
之间有两个细微的差别。第一个是fold
的形参f
的类型是X->Y->Y
,而std::accumulate
中的f
的类型规定为Y->X->Y
,所以实现时得调换一下参数顺序。要是不明白的话直接看std::accumulate
的实现吧:
template < class InputIterator , class Tp , class BinaryOperation >
Tp accumulate ( InputIterator first , InputIterator last , Tp init , BinaryOperation binary_op )
{
for ( ; first != last ; ++ first ) {
init = binary_op ( init , * first );
}
return init ;
}
第二个差别是std::accumulate
是从前往后执行,而fold
是从后往前执行,为什么fold
是从后往前呢?这个与list
的定义有关:
Inductive list ( X : Type ) : Type :=
| nil : list X
| cons : X -> list X -> list X .
Arguments cons { X } _ _ . (* use underscore for argument position that has no name *)
Notation "x :: y" := ( cons x y )
( at level 60 , right associativity ) .
上面的cons
的类型是forall X : Type, X -> list X -> list X
,意思是第一个参数的类型是X
(其实真正的第一个参数的类型是Type
,而它类似于 C++ 中的模板参数,可自动推导,所以可以省略),第二个参数的类型是list X
,那么结果的类型是list X
,这是根本原因,导致了fold
从后往前执行。用std::accumulate
实现时可以用反向迭代器解决。
5. 参考资料
[1] Software Foundations
[2] C++ Primer
[3] STL 源码剖析
6. 附录(源程序)
#include <iostream>
#include <list>
#include <iterator>
#include <algorithm>
#include <string>
#include <cassert>
using namespace std ;
template < class Y , class BinaryOperation , class List >
Y fold ( BinaryOperation f , List l , Y b )
{
if ( l . empty ()) {
return b ;
}
auto h = l . front ();
l . pop_front ();
return f ( h , fold ( f , l , b ));
}
template < class X , class UnaryOperation , template < class , class ... > class List >
auto map ( UnaryOperation f , List < X > l )
{
if ( l . empty ()) {
return List < decltype ( f ( X ())) > ();
}
X h = l . front ();
l . pop_front ();
auto ret = map ( f , l );
ret . push_front ( f ( h ));
return ret ;
}
/*
* use fold to implement map
*/
template < class X , class UnaryOperation , template < class , class ... > class List >
auto fold_map ( UnaryOperation f , List < X > l )
{
using L = List < decltype ( f ( X ())) > ;
auto bin_fun = [ & f ] ( X x , L lst ) {
lst . push_front ( f ( x ));
return lst ;
};
return fold ( bin_fun , l , L ());
}
int main ( int argc , char ** argv )
{
auto f = [] ( auto x ) {
return "element is " + to_string ( x );
};
list < short > l = { 1 , 2 };
auto lst = fold_map ( f , l );
copy ( lst . begin (), lst . end (), ostream_iterator < decltype ( lst ) :: value_type > ( cout , " \n " ));
assert ( fold_map ( f , l ) == map ( f , l ));
return 0 ;
}
#include <iostream>
#include <numeric>
#include <deque>
#include <iterator>
#include <algorithm>
#include <string>
#include <cassert>
using namespace std ;
template < class Y , class BinaryOperation , class List >
Y fold ( BinaryOperation f , List l , Y b )
{
// reverse parameter list in f
auto binary_op = [ & f ] ( auto a , auto b ) {
return f ( b , a );
};
return accumulate ( crbegin ( l ), crend ( l ), b , binary_op );
}
template < class X , class UnaryOperation , template < class , class ... > class List >
auto map ( UnaryOperation f , List < X > l )
{
List < decltype ( f ( X ())) > dst_list ( l . size ());
transform ( cbegin ( l ), cend ( l ), begin ( dst_list ), f );
return dst_list ;
}
/*
* use fold to implement map
*/
template < class X , class UnaryOperation , template < class , class ... > class List >
auto fold_map ( UnaryOperation f , List < X > l )
{
using L = List < decltype ( f ( X ())) > ;
auto bin_fun = [ & f ] ( X x , L lst ) {
lst . push_front ( f ( x ));
return lst ;
};
return fold ( bin_fun , l , L ());
}
int main ( int argc , char ** argv )
{
auto f = [] ( auto x ) {
return "element is " + to_string ( x );
};
deque < short > l = { 1 , 2 };
auto lst = fold_map ( f , l );
copy ( cbegin ( lst ), cend ( lst ), ostream_iterator < decltype ( lst ) :: value_type > ( cout , " \n " ));
assert ( fold_map ( f , l ) == map ( f , l ));
return 0 ;
}
(End)