在高维向量相似性检索中,大量时间会花费在欧氏距离的计算上,因此做了一组测试:
在我的 Mac 上,不论怎么调整 JVM 的 SIMD 相关选项(AVX、MMX、SSE 等),或者打开 AggressiveOpts(-XX:+AggressiveOpts),300 维向量的 distance1 平均每次耗时都是 380ns,distance4 和 distance8 均减少到 290ns(可见矢量化在一定程度上生效了)。
而用 gcc 编译下面的 C++ 版本时发现,如果不打开编译优化开关,性能远远慢于 Java 版本(震惊):
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>

// The SSE intrinsics used by distance_SIMD only exist on x86 targets with SSE
// enabled. Guarding the headers keeps this file buildable on other
// architectures, where distance_SIMD falls back to the scalar loop.
#if defined(__SSE__)
#define DISTANCE_HAVE_SSE 1
#include <immintrin.h>
#include <x86intrin.h>
#else
#define DISTANCE_HAVE_SSE 0
#endif

#define PORTABLE_ALIGN16 __attribute__((aligned(16)))

/// Squared Euclidean distance between v0 and v1, plain scalar loop.
///
/// @param v0 first vector, at least n readable floats
/// @param v1 second vector, at least n readable floats
/// @param n  number of components to compare; n <= 0 yields 0
/// @return   sum over i of (v0[i] - v1[i])^2 (no square root is taken)
inline float distance1(const float* v0, const float* v1, int n) {
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    const float delta = v0[i] - v1[i];
    sum += delta * delta;
  }
  return sum;
}

/// Squared Euclidean distance, manually unrolled by 4 with four independent
/// partial sums; the n % 4 trailing components are handled by a scalar loop.
///
/// Parameters and return value are the same as for distance1. Because the
/// partial sums are combined in a different order, the result may differ from
/// distance1 in the last bits.
inline float distance4(const float* v0, const float* v1, int n) {
  float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f;
  const int loops = (n / 4) * 4;  // largest multiple of 4 not exceeding n

  for (int i = 0; i < loops; i += 4) {
    const float d1 = v0[i] - v1[i];
    const float d2 = v0[i + 1] - v1[i + 1];
    const float d3 = v0[i + 2] - v1[i + 2];
    const float d4 = v0[i + 3] - v1[i + 3];
    sum1 += d1 * d1;
    sum2 += d2 * d2;
    sum3 += d3 * d3;
    sum4 += d4 * d4;
  }

  float sum = sum1 + sum2 + sum3 + sum4;
  for (int i = loops; i < n; ++i) {
    const float delta = v0[i] - v1[i];
    sum += delta * delta;
  }
  return sum;
}

/// Squared Euclidean distance, manually unrolled by 8 with eight independent
/// partial sums; the n % 8 trailing components are handled by a scalar loop.
///
/// Parameters and return value are the same as for distance1.
inline float distance8(const float* v0, const float* v1, int n) {
  float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f;
  float sum5 = 0.0f, sum6 = 0.0f, sum7 = 0.0f, sum8 = 0.0f;
  const int loops = (n / 8) * 8;  // largest multiple of 8 not exceeding n

  for (int i = 0; i < loops; i += 8) {
    const float d1 = v0[i] - v1[i];
    const float d2 = v0[i + 1] - v1[i + 1];
    const float d3 = v0[i + 2] - v1[i + 2];
    const float d4 = v0[i + 3] - v1[i + 3];
    const float d5 = v0[i + 4] - v1[i + 4];
    const float d6 = v0[i + 5] - v1[i + 5];
    const float d7 = v0[i + 6] - v1[i + 6];
    const float d8 = v0[i + 7] - v1[i + 7];
    sum1 += d1 * d1;
    sum2 += d2 * d2;
    sum3 += d3 * d3;
    sum4 += d4 * d4;
    sum5 += d5 * d5;
    sum6 += d6 * d6;
    sum7 += d7 * d7;
    sum8 += d8 * d8;
  }

  float sum = sum1 + sum2 + sum3 + sum4 + sum5 + sum6 + sum7 + sum8;
  for (int i = loops; i < n; ++i) {
    const float delta = v0[i] - v1[i];
    sum += delta * delta;
  }
  return sum;
}

/// Squared Euclidean distance using 128-bit SSE intrinsics.
///
/// Processes 16 floats per iteration (four 4-wide steps into one accumulator),
/// then remaining groups of 4, then the last qty % 4 components one by one.
/// Loads are unaligned, so the inputs need no particular alignment. When SSE
/// is not available at compile time this simply forwards to distance1.
///
/// @param pVect1 first vector, at least qty readable floats
/// @param pVect2 second vector, at least qty readable floats
/// @param qty    number of components to compare; qty <= 0 yields 0
/// @return       sum over i of (pVect1[i] - pVect2[i])^2
inline float distance_SIMD(const float* pVect1, const float* pVect2, int qty) {
#if DISTANCE_HAVE_SSE
  // Avoid forming out-of-range end pointers for a non-positive length.
  if (qty <= 0) {
    return 0.0f;
  }

  const float* const end16 = pVect1 + 16 * (qty / 16);
  const float* const end4 = pVect1 + 4 * (qty / 4);
  const float* const end = pVect1 + qty;

  __m128 acc = _mm_setzero_ps();
  __m128 a, b, diff;

  // Main loop: four 4-wide steps per iteration, written out explicitly so the
  // unrolling is present even in unoptimized builds.
  while (pVect1 < end16) {
    a = _mm_loadu_ps(pVect1);
    pVect1 += 4;
    b = _mm_loadu_ps(pVect2);
    pVect2 += 4;
    diff = _mm_sub_ps(a, b);
    acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));

    a = _mm_loadu_ps(pVect1);
    pVect1 += 4;
    b = _mm_loadu_ps(pVect2);
    pVect2 += 4;
    diff = _mm_sub_ps(a, b);
    acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));

    a = _mm_loadu_ps(pVect1);
    pVect1 += 4;
    b = _mm_loadu_ps(pVect2);
    pVect2 += 4;
    diff = _mm_sub_ps(a, b);
    acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));

    a = _mm_loadu_ps(pVect1);
    pVect1 += 4;
    b = _mm_loadu_ps(pVect2);
    pVect2 += 4;
    diff = _mm_sub_ps(a, b);
    acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
  }

  // Remaining complete groups of four.
  while (pVect1 < end4) {
    a = _mm_loadu_ps(pVect1);
    pVect1 += 4;
    b = _mm_loadu_ps(pVect2);
    pVect2 += 4;
    diff = _mm_sub_ps(a, b);
    acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
  }

  // _mm_store_ps requires a 16-byte aligned destination, hence the attribute.
  float PORTABLE_ALIGN16 lanes[4];
  _mm_store_ps(lanes, acc);
  float res = lanes[0] + lanes[1] + lanes[2] + lanes[3];

  // Scalar tail for the last qty % 4 components.
  while (pVect1 < end) {
    const float d = *pVect1++ - *pVect2++;
    res += d * d;
  }
  return res;
#else
  return distance1(pVect1, pVect2, qty);
#endif
}

/// Runs one benchmark pass: fills two 300-dimensional vectors with random
/// values in [0, 1e6), calls distance4 on them 100,000,000 times and prints
/// the accumulated checksum together with the elapsed CPU time.
void test() {
  // Compile-time constants: a runtime-sized `float v[dimensions]` would be a
  // variable-length array, which is a compiler extension, not standard C++.
  constexpr int kIterations = 100000000;
  constexpr int kDimensions = 300;

  float v0[kDimensions];
  float v1[kDimensions];
  for (int i = 0; i < kDimensions; ++i) {
    v0[i] = static_cast<float>(std::rand() / (RAND_MAX + 1.0) * 1000000);
    v1[i] = static_cast<float>(std::rand() / (RAND_MAX + 1.0) * 1000000);
  }

  // Accumulate in double: every term can be as large as ~3e14, and after
  // enough additions a float accumulator grows so large that further terms
  // are rounded away, making the printed checksum meaningless.
  double sum = 0.0;

  // NOTE(review): both inputs are loop-invariant, so an optimizing compiler
  // may be able to hoist part of the work out of this loop — confirm against
  // the generated assembly before trusting optimized-build timings.
  const std::clock_t start = std::clock();
  for (int i = 0; i < kIterations; ++i) {
    sum += distance4(v0, v1, kDimensions);
  }
  const std::clock_t stop = std::clock();

  // clock() deltas scaled by 1000 / CLOCKS_PER_SEC are milliseconds of CPU
  // time, so label them as such. Printing the checksum keeps the loop's
  // result observable.
  std::cout << "Sum = " << sum << ", Using "
            << (stop - start) * 1000 / CLOCKS_PER_SEC << " ms." << '\n';
}

/// Seeds the C random generator from the current time, then runs the benchmark
/// twice: once as a warm-up pass and once as the reported pass.
int main() {
  std::srand(static_cast<unsigned>(std::time(nullptr)));

  // std::endl on purpose: flush so progress is visible before the long loop.
  std::cout << "Warming up..." << std::endl;
  test();
  std::cout << "Testing..." << std::endl;
  test();
  return 0;
}
C++ 版本(未开启优化)中 distance1 为 1090ns,distance4 和 distance8 为 630ns,不如 Java 版本快;即使代码使用了 intrinsics 头文件(immintrin.h)中的内联函数,结果也同样如此。
启用 -mavx 等编译选项后,性能没有变化。
启用 -O2 或 -O3 之后,distance4 的普通版本和 AVX 版本性能相同,都是 90ns。
JNI 的 overhead 过大,因此决定使用 CriticalNative;目前测下来,Java 版本比原始的 nmslib 更高效。