luckay-knows-demo/knows-java/src/main/java/cn/luckday/service/EsDocumentService.java
liushuang 98d7406a99 init
2025-03-05 14:14:54 +08:00

143 lines
5.6 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cn.luckday.service;
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.Script;
import co.elastic.clients.elasticsearch._types.query_dsl.*;
import co.elastic.clients.elasticsearch.core.IndexResponse;
import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.indices.CreateIndexRequest;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import co.elastic.clients.json.JsonData;
import cn.luckday.bean.SearchResult;
import cn.luckday.bean.KnowsIndex;
import cn.luckday.embed.EmbedClient;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Service that manages the {@value #INDEX_NAME} Elasticsearch index: it creates the index,
 * stores {@link KnowsIndex} documents together with an embedding of their content, and runs
 * cosine-similarity searches against the stored embeddings.
 */
@Slf4j
@Service
public class EsDocumentService {

    /** Endpoint of the embedding service, passed to {@link EmbedClient#getEmbedding}. */
    @Value("${embedding.uri}")
    private String embeddingUri;

    /** API key of the embedding service, passed to {@link EmbedClient#getEmbedding}. */
    @Value("${embedding.api-key}")
    private String embeddingApiKey;

    @Resource
    private ElasticsearchClient client;

    /** Name of the index every operation of this service works on. */
    public static final String INDEX_NAME = "knows_index";

    /** Minimum (clamped) cosine similarity a hit must reach to be returned by {@link #searchVector}. */
    public static final float SIMILARITY_THRESHOLD = 0.2f;

    /**
     * Name of the document field holding the embedding. It is shared by the index mapping and
     * the scoring script so the two cannot drift apart; it matches the property written through
     * {@code KnowsIndex#setContent_vec}.
     */
    private static final String VECTOR_FIELD = "content_vec";

    /** Name of the document field holding the text that gets embedded. */
    private static final String TEXT_FIELD = "content";

    /** Dimension declared for the dense-vector field in the index mapping. */
    private static final int VECTOR_DIMS = 1024;

    /**
     * Painless scoring script. Documents without a stored vector score 0 instead of making the
     * script fail; otherwise the cosine similarity is clamped to [0, 1] and anything below
     * {@code params.threshold} is reported as 0.
     */
    private static final String SIMILARITY_SCRIPT =
            "if (doc['" + VECTOR_FIELD + "'].size() == 0) { return 0; } "
            + "double score = cosineSimilarity(params.query_vector, '" + VECTOR_FIELD + "'); "
            + "score = Math.min(1.0, Math.max(0.0, score)); "
            + "if (score < params.threshold) { return 0; } else { return score; }";

    /**
     * Creates the index with its field mappings.
     *
     * <p>The create request is sent unconditionally; no existence check is made first.
     *
     * @throws IOException if the request to Elasticsearch fails
     */
    public void createIndex() throws IOException {
        CreateIndexRequest request = new CreateIndexRequest.Builder()
                .index(INDEX_NAME)
                .mappings(m -> m
                        .properties("file_name", p -> p.keyword(k -> k))
                        .properties("file_path", p -> p.keyword(k -> k))
                        .properties("file_type", p -> p.keyword(k -> k))
                        .properties("file_size", p -> p.keyword(k -> k))
                        // The vector must be mapped under the same name the search script reads.
                        .properties(VECTOR_FIELD, p -> p
                                .denseVector(dv -> dv
                                        .dims(VECTOR_DIMS)
                                        .index(true)
                                        .similarity("cosine")
                                )
                        )
                        .properties(TEXT_FIELD, p -> p
                                .text(t -> t)
                        )
                        // NOTE(review): "remark" is kept from the original mapping; confirm that
                        // KnowsIndex really exposes such a property.
                        .properties("remark", p -> p
                                .text(t -> t)
                        )
                        // Alternative kept for reference: analyze the text with the IK analyzer.
                        // .properties("remark", p -> p
                        //         .text(t -> t.searchAnalyzer("ik_smart")
                        //                 .analyzer("ik_smart")
                        //         )
                        // )
                )
                .build();
        CreateIndexResponse createIndexResponse = client.indices().create(request);
        log.info("Index created: {}", createIndexResponse.acknowledged());
    }

    /**
     * Computes an embedding for each document's content and indexes the documents one by one.
     *
     * @param knowsIndexList documents to index; {@code null} is treated as an empty list
     * @throws IOException if a request to Elasticsearch fails
     */
    public void indexSellList(List<KnowsIndex> knowsIndexList) throws IOException {
        if (knowsIndexList == null) {
            return;
        }
        for (KnowsIndex knowsIndex : knowsIndexList) {
            // The embedding is stored on the document itself so it is serialized with it.
            knowsIndex.setContent_vec(
                    EmbedClient.getEmbedding(embeddingUri, embeddingApiKey, knowsIndex.getContent()));
            IndexResponse response = client.index(i -> i
                    .index(INDEX_NAME)
                    .id(knowsIndex.getId())
                    .document(knowsIndex)
            );
            log.info("Sell indexed: {}", response.id());
        }
    }

    /**
     * Searches the index for documents whose stored vector is similar to {@code queryVector}.
     *
     * <p>No explicit result size is set on the request, so the server's default page size applies.
     *
     * @param queryVector query embedding; must be non-null and non-empty
     * @return hits whose score is at least {@link #SIMILARITY_THRESHOLD}, best match first
     * @throws IllegalArgumentException if {@code queryVector} is {@code null} or empty
     * @throws IOException if the request to Elasticsearch fails
     */
    public List<SearchResult> searchVector(double[] queryVector) throws IOException {
        if (queryVector == null || queryVector.length == 0) {
            throw new IllegalArgumentException("queryVector must be a non-empty array");
        }

        // Score every document by its cosine similarity to the query vector.
        Map<String, JsonData> scriptParams = Map.of(
                "query_vector", JsonData.of(queryVector),
                "threshold", JsonData.of(SIMILARITY_THRESHOLD) // threshold is handed to the script
        );
        ScriptScoreQuery scriptScoreQuery = ScriptScoreQuery.of(q -> q
                .query(QueryBuilders.matchAll().build()._toQuery())
                .script(Script.of(s -> s.inline(i -> i
                        .source(SIMILARITY_SCRIPT)
                        .params(scriptParams)
                )))
        );

        // Wrap the similarity query as the single should-clause of a bool query.
        Query boolQuery = QueryBuilders.bool(b -> b
                .should(scriptScoreQuery._toQuery())
        );
        // min_score drops the hits the script scored below the threshold.
        Query functionScoreQuery = QueryBuilders.functionScore(fs -> fs
                .query(boolQuery)
                .scoreMode(FunctionScoreMode.Max)
                .boostMode(FunctionBoostMode.Replace)
                .minScore((double) SIMILARITY_THRESHOLD)
        );

        SearchResponse<KnowsIndex> combinedSearchResponse = client.search(s -> s
                        .index(INDEX_NAME)
                        .query(functionScoreQuery),
                KnowsIndex.class);

        // Re-check the threshold client-side and order the results by descending score.
        return combinedSearchResponse.hits().hits().stream()
                .filter(hit -> scoreOf(hit.score()) >= SIMILARITY_THRESHOLD)
                .map(hit -> new SearchResult(hit.source(), scoreOf(hit.score())))
                .sorted(Comparator.comparingDouble(SearchResult::getScore).reversed())
                .collect(Collectors.toList());
    }

    /** Returns the hit score as a primitive, treating a missing score as 0. */
    private static double scoreOf(Double rawScore) {
        return Objects.nonNull(rawScore) ? rawScore : 0.0;
    }
}