简介
在进行向量相似度搜索时，通常使用近似最近邻（ANN，Approximate Nearest Neighbor）算法来加速查询向量的检索。其中，较为知名的 ANN 算法包括 HNSW、IVFFlat、IVFPQ 和 IVFSQ。在 IVF（倒排文件索引，Inverted File Index）类算法中，需要先用 K-Means 对向量进行聚类，得到各个倒排列表的中心。Elkan K-Means 是其中较为经典的实现之一：它利用三角不等式，为每个样本维护到各中心距离的上界与下界，从而跳过大量不必要的距离计算。该算法被广泛用于向量聚类和索引构建。
在 PostgreSQL 中，pgvector 扩展提供了对向量数据的索引与搜索支持，而 Elkan K-Means 算法正是其中用于加速 IVF 索引构建时聚类过程的关键技术。接下来，我们将深入解析 Elkan K-Means 算法的具体执行流程。
算法详情
以下是用 C 语言编写的完整代码（其中的宏及其他辅助内容暂不解释，主要讲解代码逻辑）：
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfTypeInfo * typeInfo)
{
FmgrInfo *procinfo;
FmgrInfo *normprocinfo;
Oid collation;
int dimensions = centers->dim;
int numCenters = centers->maxlen;
int numSamples = samples->length;
VectorArray newCenters;
float *agg;
int *centerCounts;
int *closestCenters;
float *lowerBound;
float *upperBound;
float *s;
float *halfcdist;
float *newcdist;
/* Calculate allocation sizes */
Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);
Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);
Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);
Size aggSize = sizeof(float) * (int64) numCenters * dimensions;
Size centerCountsSize = sizeof(int) * numCenters;
Size closestCentersSize = sizeof(int) * numSamples;
Size lowerBoundSize = sizeof(float) * numSamples * numCenters;
Size upperBoundSize = sizeof(float) * numSamples;
Size sSize = sizeof(float) * numCenters;
Size halfcdistSize = sizeof(float) * numCenters * numCenters;
Size newcdistSize = sizeof(float) * numCenters;
/* Calculate total size */
Size totalSize = samplesSize + centersSize + newCentersSize + aggSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;
/* Check memory requirements */
/* Add one to error message to ceil */
if (totalSize > (Size) maintenance_work_mem * 1024L)
ereport(ERROR,
(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
errmsg("memory required is %zu MB, maintenance_work_mem is %d MB",
totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024)));
/* Ensure indexing does not overflow */
if (numCenters * numCenters > INT_MAX)
elog(ERROR, "Indexing overflow detected. Please report a bug.");
/* Set support functions */
procinfo = index_getprocinfo(index, 1, IVF_KMEANS_DISTANCE_PROC);
normprocinfo = IvfOptionalProcInfo(index, IVF_KMEANS_NORM_PROC);
collation = index->rd_indcollation[0];
/* Allocate space */
/* Use float instead of double to save memory */
agg = palloc(aggSize);
centerCounts = palloc(centerCountsSize);
closestCenters = palloc(closestCentersSize);
lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);
upperBound = palloc(upperBoundSize);
s = palloc(sSize);
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
newcdist = palloc(newcdistSize);
/* Initialize new centers */
newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);
newCenters->length = numCenters;
#ifdef IVFFLAT_MEMORY
ShowMemoryUsage(MemoryContextGetParent(CurrentMemoryContext));
#elif defined IVFPQ_MEMORY
ShowMemoryUsage(MemoryContextGetParent(CurrentMemoryContext));
#endif
/* Pick initial centers */
InitCenters(index, samples, centers, lowerBound);
/* Assign each x to its closest initial center c(x) = argmin d(x,c) */
for (int64 j = 0; j < numSamples; j++)
{
float minDistance = FLT_MAX;
int closestCenter = 0;
/* Find closest center */
for (int64 k = 0; k < numCenters; k++)
{
/* TODO Use Lemma 1 in k-means++ initialization */
float distance = lowerBound[j * numCenters + k];
if (distance < minDistance)
{
minDistance = distance;
closestCenter = k;
}
}
upperBound[j] = minDistance;
closestCenters[j] = closestCenter;
}
/* Give 500 iterations to converge */
for (int iteration = 0; iteration < 500; iteration++)
{
int changes = 0;
bool rjreset;
/* Can take a while, so ensure we can interrupt */
CHECK_FOR_INTERRUPTS();
/* Step 1: For all centers, compute distance */
for (int64 j = 0; j < numCenters; j++)
{
Datum vec = PointerGetDatum(VectorArrayGet(centers, j));
for (int64 k = j + 1; k < numCenters; k++)
{
float distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, k))));
halfcdist[j * numCenters + k] = distance;
halfcdist[k * numCenters + j] = distance;
}
}
/* For all centers c, compute s(c) */
for (int64 j = 0; j < numCenters; j++)
{
float minDistance = FLT_MAX;
for (int64 k = 0; k < numCenters; k++)
{