package cn.luckday.service; import co.elastic.clients.elasticsearch.ElasticsearchClient; import co.elastic.clients.elasticsearch._types.Script; import co.elastic.clients.elasticsearch._types.query_dsl.*; import co.elastic.clients.elasticsearch.core.IndexResponse; import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.indices.CreateIndexRequest; import co.elastic.clients.elasticsearch.indices.CreateIndexResponse; import co.elastic.clients.json.JsonData; import cn.luckday.bean.SearchResult; import cn.luckday.bean.KnowsIndex; import cn.luckday.embed.EmbedClient; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import java.io.IOException; import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; @Slf4j @Service public class EsDocumentService { @Value("${embedding.uri}") private String embeddingUri; @Value("${embedding.api-key}") private String embeddingApiKey; @Resource private ElasticsearchClient client; public static final String INDEX_NAME = "knows_index"; public static final float SIMILARITY_THRESHOLD = 0.2f; /** * 创建索引 * @throws IOException 异常 */ public void createIndex() throws IOException { CreateIndexRequest request = new CreateIndexRequest.Builder() .index(INDEX_NAME) .mappings(m -> m .properties("file_name", p -> p.keyword(k -> k)) .properties("file_path", p -> p.keyword(k -> k)) .properties("file_type", p -> p.keyword(k -> k)) .properties("file_size", p -> p.keyword(k -> k)) .properties("remark_vec", p -> p .denseVector(dv -> dv .dims(1024) .index(true) .similarity("cosine") ) ) .properties("remark", p -> p .text(t -> t) ) // .properties("remark", p -> p // .text(t -> t.searchAnalyzer("ik_smart") // .analyzer("ik_smart") // 使用 IK 分词器 // ) // ) ) .build(); CreateIndexResponse createIndexResponse = client.indices().create(request); log.info("Index created: {}", createIndexResponse.acknowledged()); } /** * 添加数据 * @param knowsIndexList 数据 * @throws IOException 异常 */ public void indexSellList(List knowsIndexList) throws IOException { for (KnowsIndex knowsIndex : knowsIndexList) { knowsIndex.setContent_vec(EmbedClient.getEmbedding(embeddingUri, embeddingApiKey, knowsIndex.getContent())); IndexResponse response = client.index(i -> i .index(INDEX_NAME) .id(knowsIndex.getId()) .document(knowsIndex) ); log.info("Sell indexed: {}", response.id()); } } /** * 检索 * * @param queryVector 向量 */ public List searchVector(double[] queryVector) throws IOException { // 创建向量相似度查询 ScriptScoreQuery scriptScoreQuery = ScriptScoreQuery.of(q -> q .query(QueryBuilders.matchAll().build()._toQuery()) .script(Script.of(s -> s.inline(i -> i .source("double score = cosineSimilarity(params.query_vector, 'content_vec'); " + "score = Math.min(1.0, Math.max(0.0, score)); " + // 确保评分在[0, 1]之间 "if (score < params.threshold) { return 0; } else { return score; }") .params(Map.of( "query_vector", JsonData.of(queryVector), "threshold", JsonData.of(SIMILARITY_THRESHOLD) // 将阈值作为参数传递给脚本 )))))); // 创建bool查询,向量相似度查询作为should子句 Query boolQuery = QueryBuilders.bool(b -> b .should(scriptScoreQuery._toQuery()) ); Query functionScoreQuery = QueryBuilders.functionScore(fs -> fs .query(boolQuery) .scoreMode(FunctionScoreMode.Max) .boostMode(FunctionBoostMode.Replace) .minScore((double) SIMILARITY_THRESHOLD) ); // 执行合并后的查询 SearchResponse combinedSearchResponse = client.search(s -> s .index(INDEX_NAME) .query(functionScoreQuery), KnowsIndex.class); // 处理查询的结果 return combinedSearchResponse.hits().hits().stream() .map(hit -> { double finalScore = Objects.nonNull(hit.score()) ? hit.score() : 0.0; return finalScore >= SIMILARITY_THRESHOLD ? new SearchResult(hit.source(), finalScore) : null; }) .filter(Objects::nonNull) .sorted(Comparator.comparingDouble(SearchResult::getScore).reversed()) .collect(Collectors.toList()); } }