向量检索Faiss实战

faiss简介

Faiss is a library for efficient similarity search and clustering of dense vectors。

官方介绍: Faiss是一个用于高效相似性搜索和密集向量聚类的库。也就是用来实现高效的向量检索。

Faiss主要组件包括:

  1. 索引结构:Flat(暴力搜索) 、IVF(Inverted File)、IVFPQ(Inverted File with Product Quantization)、HNSW(Hierarchical Navigable Small World),索引结构可以加速相似性搜索,降低查询时间。
  2. 向量编码:PQ(Product Quantization)、OPQ(Optimized Product Quantization)。编码可以将高维向量映射到低维空间中,同时保持距离的相似性,有助于减少内存占用和计算量。
  3. 相似性度量:欧氏距离、内积、Jaccard 相似度等。

Faiss的核心API有:

  1. IndexFactory(d int, description string, metric int):用来创建索引,通过维度,索引方法描述,相似性度量来创建索引。

  2. Ntotal() 索引向量的数量。

  3. Train(x []float32) 用一组具有代表性的向量训练索引。

  4. Add(x []float32),用于创建向量检索集。

  5. Search(x []float32, k int64) (distances []float32, labels []int64, err error),x向量在k紧邻进行检索,返回每个查询向量的 k 个最近邻的 ID 以及相应的距离。

如何理解AddSearch方法呢?Add是添加向量,Search从向量中检索。比如一篇文章拆分成5个片段,此时调用Add方法生成了5个向量,查询的内容会生成一个查询向量,那么Searchk=2会返回最近的两个近邻,也就是返回5个向量中的2个向量,那么返回值distances是查询向量到返回2个向量的距离,返回值labels是返回的向量在5个片段中的位置,此时就可以知道返回了那些段。

Faiss的主要流程是:

  1. 初始化索引结构,指定相似性度量方法(metric)和编码方法(description)。使用IndexFactory
  2. 将原始向量数据添加到索引中。使用Add
  3. 对查询向量进行编码,并在索引中搜索与查询向量相似的向量。使用Search
  4. 获取搜索结果,并根据需要进行后处理。

文档向量化检索设计

如果我们要实现一篇文档的向量化检索该如何设计呢?可以使用mysql和内存缓存作为文档的向量存储,方案可以先将文档拆分,然后存储到数据库中,表设计如下:

mysql存储拆分后的文档—— primary_id,edoc_part_content,project_id,embedding

内存缓存存储向量位置到文档主键ID——键:project_id+Ntotal() 值:primary_id

服务启动初始化时候从mysql加载doc表,获取到所有的文档,然后通过Add方法加载到检索集中,每加一次,调用Ntotal方法获取当前向量总数,也就是当前向量数组的位置下标,存入内存缓存中,

查询时候,生成查询向量后,调用Search方法,获取到检索集位置,然后获取从内存缓存中获取mysql中的主键id,去mysql查询到文档的内容。

Faiss配置指南

相似性计算方法

相似性计算主要有余弦,L1,L2等计算方法。

InnerProduct内积/余弦相似度

L1 曼哈顿距离

L2 欧氏距离

Linf 无穷范数

Lp p范数

Canberra BC相异度

BrayCurtis 兰氏距离/堪培拉距离

JensenShannon JS散度

索引方法

索引描述主要是向量检索算法。主要有以下几个:

Flat:最基础的索引结构,比较精确

IVF:Inverted File 倒排文件

PQ:Product Quantization 乘积量化

PCA:Principal Component Analysis 主成分分析

HNSW:Hierarchical Navigable Small World 分层的可导航小世界

相似性计算方法 索引描述 说明
InnerProduct Flat 余弦相似度 暴力检索
InnerProduct IVF100,Flat 余弦相似度 k-means聚类中心为100倒排(IVFx)暴力检索
L2 Flat 欧式距离 暴力检索
InnerProduct PQ16 余弦相似度 乘积量化 利用乘积量化的方法,改进了普通检索,将一个向量的维度切成x段,每段分别进行检索,每段向量的检索结果取交集后得出最后的TopK。因此速度很快,而且占用内存较小,召回率也相对较高
L2 PCA32,IVF100,PQ16 欧式距离 将向量先降维成32维,再用IVF100 PQ16的方法构建索引
L2 PCA32,HNSW32 欧式距离 处理HNSW内存占用过大的问题
L2 IVF100,PQ16 欧式距离 倒排乘积量化:工业界大量使用此方法,各项指标都均可以接受,利用乘积量化的方法,改进了IVF的k-means,将一个向量的维度切成x段,每段分别进行k-means再检索
其他 其他 大家自己枚举调优吧,采用下文测试方法测试是否成功

GoLang代码例子

faiss本身用C++实现,这里使用go-faiss来实现例子,embeding获取通过openai的接口实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
package services

import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"math/rand"

gofaiss "github.com/DataIntelligenceCrew/go-faiss"
"github.com/sashabaranov/go-openai"
"github.com/spf13/cast"
)

const (
AuthToken = "openai的token"
)

var MetricTypeMap = map[string]int{
"InnerProduct": gofaiss.MetricInnerProduct, // 0
"L2": gofaiss.MetricL2, // 1
"L1": gofaiss.MetricL1, // 2
"Linf": gofaiss.MetricLinf, // 3
"Lp": gofaiss.MetricLp, // 4
"Canberra": gofaiss.MetricCanberra, // 20
"BrayCurtis": gofaiss.MetricBrayCurtis, // 21
"JensenShannon": gofaiss.MetricJensenShannon, // 22
}

type FaissReq struct {
IsDemo bool `json:"is_demo"`
DBSize int `json:"db_size"`
QuerySize int `json:"query_size"`
KNearest int64 `json:"k_nearest"`
Question string `json:"question"`
Model string `json:"model"`
Embedding []float32 `json:"embedding"`
Dimension int `json:"dimension"` // 维度
Description string `json:"description"` // 索引描述
Metric string `json:"metric"` // 相似性度量方法
}

type FaissRsp struct {
IsTrained bool `json:"is_trained"`
Ntotal int64 `json:"n_total"`
Dimension int `json:"dimension"`
}

type EmbeddingReq struct {
Prompt string `json:"prompt"`
Model string `json:"model"`
}

type EmbeddingRsp struct {
Embedding []float32 `json:"embedding"`
Time string `json:"time"`
}

func QueryFaiss(req FaissReq) (rsp FaissRsp, err error) {
log.Printf("all metrics is:%+v", MetricTypeMap)
if req.IsDemo {
d := req.Dimension // 向量维度
nb := req.DBSize // 全部数据大小
nq := req.QuerySize //

// 所有数据的向量
xb := make([]float32, d*nb)
// 查询数据的向量
xq := make([]float32, d*nq)

// 初始化全部数据
for i := 0; i < nb; i++ {
for j := 0; j < d; j++ {
xb[i*d+j] = rand.Float32()
}
xb[i*d] += float32(i) / 1000
}

// 初始化查询数据
for i := 0; i < nq; i++ {
for j := 0; j < d; j++ {
xq[i*d+j] = rand.Float32()
}
xq[i*d] += float32(i) / 1000
}

// 初始化Faiss
indexImpl, err := gofaiss.IndexFactory(d, req.Description, MetricTypeMap[req.Metric])
if err != nil {
log.Printf("IndexFactory err:%+v", err)
return FaissRsp{}, err
}
defer indexImpl.Delete()

// 训练全部数据
trainErr := indexImpl.Train(xb)
if trainErr != nil {
log.Printf("Train err:%+v", trainErr)
return FaissRsp{}, err
}
// 将全部数据加入Faiss中
addErr := indexImpl.Add(xb)
if err != nil {
log.Printf("addErr err:%+v", addErr)
return FaissRsp{}, err
}
k := int64(4)

// 合法性检查,用全部数据的前5*维度个
dist, ids, err := indexImpl.Search(xb[:5*d], k)
if err != nil {
log.Printf("Search err:%+v", err)
return FaissRsp{}, err
}
log.Printf("Search dist:%+v,ids:%+v", dist, ids)

fmt.Println("ids=")
for i := int64(0); i < 5; i++ {
for j := int64(0); j < k; j++ {
fmt.Printf("%5d ", ids[i*k+j])
}
fmt.Println()
}

fmt.Println("dist=")
for i := int64(0); i < 5; i++ {
for j := int64(0); j < k; j++ {
fmt.Printf("%7.6g ", dist[i*k+j])
}
fmt.Println()
}

// 通过查询数据xq进行向量检索
ps, err := gofaiss.NewParameterSpace()
if err != nil {
log.Printf("NewParameterSpace err:%+v", err)
return FaissRsp{}, err
}
defer ps.Delete()

if err := ps.SetIndexParameter(indexImpl, "nprobe", 10); err != nil {
log.Printf("SetIndexParameter err:%+v", err)
return FaissRsp{}, err
}

_, ids, err = indexImpl.Search(xq, k)
if err != nil {
log.Printf(" indexImpl.Search Last err:%+v", err)
return FaissRsp{}, err
}

fmt.Println("ids (last 5 results)=")
for i := int64(nq) - 5; i < int64(nq); i++ {
for j := int64(0); j < k; j++ {
fmt.Printf("%5d ", ids[i*k+j])
}
fmt.Println()
}
return FaissRsp{}, nil
}

indexImpl, err := gofaiss.IndexFactory(req.Dimension, req.Description, MetricTypeMap[req.Metric])
if err != nil {
log.Printf("IndexFactory error:%+v,req:%+v", err, req)
return FaissRsp{}, err
}

var embeddingArray []float32
if len(req.Question) != 0 {
embedding, embeddingErr := Embedding(context.Background(), EmbeddingReq{Prompt: req.Question, Model: req.Model})
if embeddingErr != nil {
log.Printf("Embedding err:%+v", embeddingErr)
return FaissRsp{}, embeddingErr
}
embeddingArray = embedding.Embedding
} else {
embeddingArray = req.Embedding
}

log.Printf("embedding is:%s", jsonString(embeddingArray))

err = indexImpl.Train(embeddingArray)
if err != nil {
log.Printf("indexImpl.Train error:%+v,req:%+v", err, req)
return FaissRsp{}, err
}
err = indexImpl.Add(embeddingArray)
if err != nil {
log.Printf("indexImpl.Add error:%+v,req:%+v", err, req)
return FaissRsp{}, err
}

dist, ids, err := indexImpl.Search(embeddingArray, req.KNearest)
if err != nil {
log.Printf("Search err:%+v", err)
return FaissRsp{}, err
}
log.Printf("Search dist:%s,\n ids:%s", jsonString(dist), jsonInt64String(ids))

return FaissRsp{IsTrained: indexImpl.IsTrained(), Ntotal: indexImpl.Ntotal(), Dimension: indexImpl.D()}, nil
}

func jsonString(data []float32) string {
marshal, _ := json.Marshal(data)
return string(marshal)
}

func jsonInt64String(data []int64) string {
marshal, _ := json.Marshal(data)
return string(marshal)
}

// Embedding 根据openai获取embedding
func Embedding(ctx context.Context, req EmbeddingReq) (rsp EmbeddingRsp, err error) {

var model openai.EmbeddingModel
if len(req.Model) == 0 {
model = openai.AdaEmbeddingV2
} else {
model = openai.EmbeddingModel(cast.ToInt(req.Model))
}

cfg := openai.DefaultConfig(AuthToken)
cfg.BaseURL = "https://api.aiproxy.io/v1"
client := openai.NewClientWithConfig(cfg)
resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Input: req.Prompt,
Model: model,
})
if err != nil {
log.Printf("Embedding error: %v,question:%s,model:%s,resp:%+v", err, req.Prompt, model)
return rsp, err
}

if len(resp.Data) > 0 {
return EmbeddingRsp{
Embedding: resp.Data[0].Embedding,
Time: "",
}, nil
}
return rsp, errors.New("没有Embeddings")
}

参考

https://github.com/facebookresearch/faiss

https://github.com/DataIntelligenceCrew/go-faiss

https://zhuanlan.zhihu.com/p/357414033

https://guangzhengli.com/blog/zh/vector-database/

https://faiss.ai/index.html

https://github.com/sashabaranov/go-openai/blob/master/embeddings.go

https://platform.openai.com/docs/guides/embeddings

https://openai.com/blog/new-and-improved-embedding-model