143 lines
5.6 KiB
Java
143 lines
5.6 KiB
Java
package cn.luckday.service;
|
||
|
||
|
||
import co.elastic.clients.elasticsearch.ElasticsearchClient;
|
||
import co.elastic.clients.elasticsearch._types.Script;
|
||
import co.elastic.clients.elasticsearch._types.query_dsl.*;
|
||
import co.elastic.clients.elasticsearch.core.IndexResponse;
|
||
import co.elastic.clients.elasticsearch.core.SearchResponse;
|
||
import co.elastic.clients.elasticsearch.indices.CreateIndexRequest;
|
||
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
|
||
import co.elastic.clients.json.JsonData;
|
||
import cn.luckday.bean.SearchResult;
|
||
import cn.luckday.bean.KnowsIndex;
|
||
import cn.luckday.embed.EmbedClient;
|
||
import jakarta.annotation.Resource;
|
||
import lombok.extern.slf4j.Slf4j;
|
||
import org.springframework.beans.factory.annotation.Value;
|
||
import org.springframework.stereotype.Service;
|
||
|
||
import java.io.IOException;
|
||
import java.util.Comparator;
|
||
import java.util.List;
|
||
import java.util.Map;
|
||
import java.util.Objects;
|
||
import java.util.stream.Collectors;
|
||
|
||
@Slf4j
|
||
@Service
|
||
public class EsDocumentService {
|
||
|
||
@Value("${embedding.uri}")
|
||
private String embeddingUri;
|
||
|
||
@Value("${embedding.api-key}")
|
||
private String embeddingApiKey;
|
||
|
||
@Resource
|
||
private ElasticsearchClient client;
|
||
|
||
public static final String INDEX_NAME = "knows_index";
|
||
|
||
public static final float SIMILARITY_THRESHOLD = 0.2f;
|
||
|
||
/**
|
||
* 创建索引
|
||
* @throws IOException 异常
|
||
*/
|
||
public void createIndex() throws IOException {
|
||
CreateIndexRequest request = new CreateIndexRequest.Builder()
|
||
.index(INDEX_NAME)
|
||
|
||
.mappings(m -> m
|
||
.properties("file_name", p -> p.keyword(k -> k))
|
||
.properties("file_path", p -> p.keyword(k -> k))
|
||
.properties("file_type", p -> p.keyword(k -> k))
|
||
.properties("file_size", p -> p.keyword(k -> k))
|
||
.properties("remark_vec", p -> p
|
||
.denseVector(dv -> dv
|
||
.dims(1024)
|
||
.index(true)
|
||
.similarity("cosine")
|
||
)
|
||
)
|
||
.properties("remark", p -> p
|
||
.text(t -> t)
|
||
)
|
||
// .properties("remark", p -> p
|
||
// .text(t -> t.searchAnalyzer("ik_smart")
|
||
// .analyzer("ik_smart") // 使用 IK 分词器
|
||
// )
|
||
// )
|
||
)
|
||
.build();
|
||
|
||
CreateIndexResponse createIndexResponse = client.indices().create(request);
|
||
log.info("Index created: {}", createIndexResponse.acknowledged());
|
||
}
|
||
|
||
/**
|
||
* 添加数据
|
||
* @param knowsIndexList 数据
|
||
* @throws IOException 异常
|
||
*/
|
||
public void indexSellList(List<KnowsIndex> knowsIndexList) throws IOException {
|
||
for (KnowsIndex knowsIndex : knowsIndexList) {
|
||
knowsIndex.setContent_vec(EmbedClient.getEmbedding(embeddingUri, embeddingApiKey, knowsIndex.getContent()));
|
||
IndexResponse response = client.index(i -> i
|
||
.index(INDEX_NAME)
|
||
.id(knowsIndex.getId())
|
||
.document(knowsIndex)
|
||
);
|
||
log.info("Sell indexed: {}", response.id());
|
||
}
|
||
}
|
||
|
||
|
||
/**
|
||
* 检索
|
||
*
|
||
* @param queryVector 向量
|
||
*/
|
||
public List<SearchResult> searchVector(double[] queryVector) throws IOException {
|
||
// 创建向量相似度查询
|
||
ScriptScoreQuery scriptScoreQuery = ScriptScoreQuery.of(q -> q
|
||
.query(QueryBuilders.matchAll().build()._toQuery())
|
||
.script(Script.of(s -> s.inline(i -> i
|
||
.source("double score = cosineSimilarity(params.query_vector, 'content_vec'); " +
|
||
"score = Math.min(1.0, Math.max(0.0, score)); " + // 确保评分在[0, 1]之间
|
||
"if (score < params.threshold) { return 0; } else { return score; }")
|
||
.params(Map.of(
|
||
"query_vector", JsonData.of(queryVector),
|
||
"threshold", JsonData.of(SIMILARITY_THRESHOLD) // 将阈值作为参数传递给脚本
|
||
))))));
|
||
|
||
// 创建bool查询,向量相似度查询作为should子句
|
||
Query boolQuery = QueryBuilders.bool(b -> b
|
||
.should(scriptScoreQuery._toQuery())
|
||
);
|
||
|
||
Query functionScoreQuery = QueryBuilders.functionScore(fs -> fs
|
||
.query(boolQuery)
|
||
.scoreMode(FunctionScoreMode.Max)
|
||
.boostMode(FunctionBoostMode.Replace)
|
||
.minScore((double) SIMILARITY_THRESHOLD)
|
||
);
|
||
|
||
// 执行合并后的查询
|
||
SearchResponse<KnowsIndex> combinedSearchResponse = client.search(s -> s
|
||
.index(INDEX_NAME)
|
||
.query(functionScoreQuery),
|
||
KnowsIndex.class);
|
||
|
||
// 处理查询的结果
|
||
return combinedSearchResponse.hits().hits().stream()
|
||
.map(hit -> {
|
||
double finalScore = Objects.nonNull(hit.score()) ? hit.score() : 0.0;
|
||
return finalScore >= SIMILARITY_THRESHOLD ? new SearchResult(hit.source(), finalScore) : null;
|
||
})
|
||
.filter(Objects::nonNull)
|
||
.sorted(Comparator.comparingDouble(SearchResult::getScore).reversed())
|
||
.collect(Collectors.toList());
|
||
}
|
||
} |