罗龙
机器学习需要训练大量数据,涉及大量复杂运算,如卷积、矩阵等。这些复杂的运算不仅数量众多,而且每次计算的数据量也非常大。如果能优化这些操作,性能可以大大提高。
00-1010:假设是和的矩阵是的矩阵,那么这个矩阵就是和的乘积,叫做矩阵乘积。
矩阵行和列中的元素可以表示为:
如下图所示:
矩阵乘法
如果在矩阵和矩阵中,需要多少次乘法才能完成?
对于每个行向量,总共有行;对于每个列向量,总共有列;计算它们的内积,总共有乘法的次数。可以看出,矩阵乘法的算法复杂度为:
00-1010那么有没有比更快的算法呢?
1969年,Volker Strassen提出第一种算法的时间复杂度低于矩阵乘法算法,算法复杂度为0。从下图可以看出,Strassen算法只能用于维数较大的矩阵。
,性能有很大的优势,可以减少很多乘法运算。
2 x^3对x^2.807
我猜的
Strassen算法证明了矩阵乘法的时间复杂度小于
算法的存在,后续学者不断研究,寻找新的更快的算法。到目前为止,时间复杂度最低的矩阵乘法算法是Coppersmith-Winograd方法的扩展方法,其算法复杂度为。
一、矩阵乘法
假设矩阵总和矩阵
都是
寻找的方阵
,如下图所示:
,
,
在…之中
矩阵
可以通过以下公式找到:
从上面的公式,我们可以得出结论,我们可以计算2
的矩阵乘法需要2。
矩阵的8次乘法和4次加法。我们使用
表达
矩阵乘法的时间复杂度,那么我们可以根据上面的分解得到下面的递归公式:
其中,
表示矩阵相乘的8次,相乘后的矩阵的比例缩小为。2.
表示第4次矩阵加法和合并矩阵的时间复杂度。
的时间复杂性。
最后,可以计算出来。
可以看出,每次递归运算需要8次矩阵乘法,这是瓶颈的来源。与加法相比,矩阵乘法非常慢,所以我们想知道是否可以减少矩阵乘法的次数。
答案当然是!
从这个角度来看,Strassen算法可以降低算法的复杂度!
实施步骤可分为以下四个步骤:
00-1010按照上述方法分解矩阵(需要时间)。2.如下创建10个
矩阵
(慢慢来
)。
3.递归计算7个矩阵乘积。
,每个矩阵
都是
是的。
请注意,只需要计算上面公式中的中间一列。
4.及格
计算
,慢慢来。
下面的递归公式可以通过综合得到:
此外,时间复杂度如下:
00-1010我们从在MNN实现Strassen算法中学到的源代码:https://github.com/Alibaba/MNN/blob/master/source/back end/CPU/compute/strassenmumulcomputer . CPP
类StrassenMatrixComputor提供了三个要调用的API:
_ generaterivalmatmul(co
nst Tensor* AT, const Tensor* BT, const Tensor* CT);普通矩阵乘法计算
_generateMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);Strassen算法的矩阵乘法
_generateMatMulConstB(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);Strassen算法的矩阵乘法(和MatMul的区别在于内存Buffer是否允许复用)
我们以_generateMatMul为例来学习下Strassen算法如何实现,可以分成如下几步:
第一步:使用Strassen算法收益判断
在矩阵操作中,因为需要对矩阵的维数进行扩展,涉及大量读写操作,这些读写操作都需要大量循环,如果读写次数超出使用Strassen乘法的收益的话,就得不偿失了,那么就使用普通的矩阵乘法。
/* Compute the memory read / write cost for expand Matrix Mul need eSub*lSub*hSub*(1+1.0/CONVOLUTION_TILED_NUMBWR), Matrix Add/Sub need x*y*UNIT*3 (2 read 1 write) */ float saveCost = (eSub * lSub * hSub) * (1.0f + 1.0f / CONVOLUTION_TILED_NUMBWR) - 4 * (eSub * lSub) * 3 - 7 * (eSub * hSub * 3); if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBWR || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) { return _generateTrivialMatMul(AT, BT, CT); }第二步:分块
将矩阵
,
,
3个矩阵都分成4块:
auto aStride = AT->stride(0); auto a11 = AT->host<float>() + 0 * aUnit * eSub + 0 * aStride * lSub; auto a12 = AT->host<float>() + 0 * aUnit * eSub + 1 * aStride * lSub; auto a21 = AT->host<float>() + 1 * aUnit * eSub + 0 * aStride * lSub; auto a22 = AT->host<float>() + 1 * aUnit * eSub + 1 * aStride * lSub; auto bStride = BT->stride(0); auto b11 = BT->host<float>() + 0 * bUnit * lSub + 0 * bStride * hSub; auto b12 = BT->host<float>() + 0 * bUnit * lSub + 1 * bStride * hSub; auto b21 = BT->host<float>() + 1 * bUnit * lSub + 0 * bStride * hSub; auto b22 = BT->host<float>() + 1 * bUnit * lSub + 1 * bStride * hSub; auto cStride = CT->stride(0); auto c11 = CT->host<float>() + 0 * aUnit * eSub + 0 * cStride * hSub; auto c12 = CT->host<float>() + 0 * aUnit * eSub + 1 * cStride * hSub; auto c21 = CT->host<float>() + 1 * aUnit * eSub + 0 * cStride * hSub; auto c22 = CT->host<float>() + 1 * aUnit * eSub + 1 * cStride * hSub;第三步:分治和递归
Strassen算法核心就是分治思想。这一步可以写成下列所示伪代码:
1. If n = 1 Output A × B 2. Else 3. Compute A11,B11, . . . ,A22,B22 % by computing m = n/2 4. P1 Strassen(A11,B12 − B22) 5. P2 Strassen(A11 + A12,B22) 6. P3 Strassen(A21 + A22,B11) 7. P4 Strassen(A22,B21 − B11) 8. P5 Strassen(A11 + A22,B11 + B22) 9. P6 Strassen(A12 − A22,B21 + B22) 10. P7 Strassen(A11 − A21,B11 + B12) 11. C11 P5 + P4 − P2 + P6 12. C12 P1 + P2 13. C21 P3 + P4 14. C22 P1 + P5 − P3 − P7 15. Output C 16. End If例如其中的一步代码如下所示:
{ // S1=A21+A22, T1=B12-B11, P5=S1T1 auto f = [a22, a21, b11, b12, xAddr, yAddr, eSub, lSub, hSub, aStride, bStride]() { MNNMatrixAdd(xAddr, a21, a22, eSub * aUnit / 4, eSub * aUnit, aStride, aStride, lSub); MNNMatrixSub(yAddr, b12, b11, lSub * bUnit / 4, lSub * bUnit, bStride, bStride, hSub); }; mFunctions.emplace_back(f); auto code = _generateMatMul(X.get(), Y.get(), C22.get(), currentDepth); if (code != NO_ERROR) { return code; } }递归执行,得到最终结果!