Browse Source

Merge pull request #16 from h794629435/main

本地向量化
ageerle 1 month ago
parent
commit
a63cc48789

+ 38 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java

@@ -0,0 +1,38 @@
+package org.ruoyi.common.chat.entity.models;
+
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * @program: RUOYIAI
+ * @ClassName LocalModelsSearchRequest
+ * @description:
+ * @author: hejh
+ * @create: 2025-03-15 17:22
+ * @Version 1.0
+ **/
+@Data
+public class LocalModelsSearchRequest {
+
+    private List<String> text;
+    private String model_name;
+    private String delimiter;
+    private int k;
+    private int block_size;
+    private int overlap_chars;
+
+    // 构造函数、Getter 和 Setter
+    public LocalModelsSearchRequest(List<String> text, String model_name, String delimiter, int k, int block_size, int overlap_chars) {
+        this.text = text;
+        this.model_name = model_name;
+        this.delimiter = delimiter;
+        this.k = k;
+        this.block_size = block_size;
+        this.overlap_chars = overlap_chars;
+    }
+
+
+}
+
+

+ 20 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java

@@ -0,0 +1,20 @@
+package org.ruoyi.common.chat.entity.models;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+
+import java.util.List;
+
+@Data
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class LocalModelsSearchResponse {
+    @JsonProperty("topKEmbeddings")
+
+    private List<List<List<Double>>> topKEmbeddings;  // 处理三层嵌套数组
+
+    // 默认构造函数
+    public LocalModelsSearchResponse() {}
+
+
+
+}

+ 198 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java

@@ -0,0 +1,198 @@
+package org.ruoyi.common.chat.localModels;
+
+import io.micrometer.common.util.StringUtils;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.OkHttpClient;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import org.springframework.stereotype.Service;
+import retrofit2.Call;
+import retrofit2.Callback;
+import retrofit2.Response;
+import retrofit2.Retrofit;
+import retrofit2.converter.jackson.JacksonConverterFactory;
+
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+
+@Slf4j
+@Service
+public class LocalModelsofitClient {
+    private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL
+    private static Retrofit retrofit = null;
+
+    // 获取 Retrofit 实例
+    public static Retrofit getRetrofitInstance() {
+        if (retrofit == null) {
+            OkHttpClient client = new OkHttpClient.Builder()
+                    .build();
+
+            retrofit = new Retrofit.Builder()
+                    .baseUrl(BASE_URL)
+                    .client(client)
+                    .addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换
+                    .build();
+        }
+        return retrofit;
+    }
+
+    /**
+     * 向 Flask 服务发送文本向量化请求
+     *
+     * @param queries 查询文本列表
+     * @param modelName 模型名称
+     * @param delimiter 文本分隔符
+     * @param topK 返回的结果数
+     * @param blockSize 文本块大小
+     * @param overlapChars 重叠字符数
+     * @return 返回计算得到的 Top K 嵌入向量列表
+     */
+
+    public static List<List<Double>> getTopKEmbeddings(
+            List<String> queries,
+            String modelName,
+            String delimiter,
+            int topK,
+            int blockSize,
+            int overlapChars) {
+
+        modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称
+        delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : ".";                             // 默认分隔符
+        topK = (topK > 0) ? topK : 3;                                                  // 默认返回 3 个结果
+        blockSize = (blockSize > 0) ? blockSize : 500;                                 // 默认文本块大小为 500
+        overlapChars = (overlapChars > 0) ? overlapChars : 50;                         // 默认重叠字符数为 50
+
+        // 创建 Retrofit 实例
+        Retrofit retrofit = getRetrofitInstance();
+
+        // 创建 SearchService 接口
+        SearchService service = retrofit.create(SearchService.class);
+
+        // 创建请求对象 LocalModelsSearchRequest
+        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
+                queries,            // 查询文本列表
+                modelName,          // 模型名称
+                delimiter,          // 文本分隔符
+                topK,               // 返回的结果数
+                blockSize,          // 文本块大小
+                overlapChars        // 重叠字符数
+        );
+
+        final CountDownLatch latch = new CountDownLatch(1);  // 创建一个 CountDownLatch
+        final List<List<Double>>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List)
+
+        // 发起异步请求
+        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
+            @Override
+            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
+                if (response.isSuccessful()) {
+                    LocalModelsSearchResponse searchResponse = response.body();
+                    if (searchResponse != null) {
+                        topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0);  // 获取结果
+                        log.info("Successfully retrieved embeddings");
+                    } else {
+                        log.error("Response body is null");
+                    }
+                } else {
+                    log.error("Request failed. HTTP error code: " + response.code());
+                }
+                latch.countDown();  // 请求完成,减少计数
+            }
+
+            @Override
+            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
+                t.printStackTrace();
+                log.error("Request failed: ", t);
+                latch.countDown();  // 请求失败,减少计数
+            }
+        });
+
+        try {
+            latch.await();  // 等待请求完成
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+
+        return topKEmbeddings[0];  // 返回结果
+    }
+
+//    public static void main(String[] args) {
+//        // 示例调用
+//        List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
+//        String modelName = "msmarco-distilbert-base-tas-b";
+//        String delimiter = ".";
+//        int topK = 3;
+//        int blockSize = 500;
+//        int overlapChars = 50;
+//
+//        List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
+//
+//        // 打印结果
+//        if (topKEmbeddings != null) {
+//            System.out.println("Top K embeddings: ");
+//            for (List<Double> embedding : topKEmbeddings) {
+//                System.out.println(embedding);
+//            }
+//        } else {
+//            System.out.println("No embeddings returned.");
+//        }
+//    }
+
+
+//    public static void main(String[] args) {
+//        // 创建 Retrofit 实例
+//        Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
+//
+//        // 创建 SearchService 接口
+//        SearchService service = retrofit.create(SearchService.class);
+//
+//        // 创建请求对象 LocalModelsSearchRequest
+//        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
+//                Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表
+//                "msmarco-distilbert-base-tas-b",  // 模型名称
+//                ".",  // 分隔符
+//                3,  // 返回的结果数
+//                500,  // 文本块大小
+//                50  // 重叠字符数
+//        );
+//
+//        // 发起请求
+//        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
+//            @Override
+//            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
+//                if (response.isSuccessful()) {
+//                    LocalModelsSearchResponse searchResponse = response.body();
+//                    System.out.println("Response Body: " + response.body());  // Print the whole response body for debugging
+//
+//                    if (searchResponse != null) {
+//                        // If the response is not null, process it.
+//                        // Example: Extract the embeddings and print them
+//                        List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
+//                        if (topKEmbeddings != null) {
+//                            // Print the Top K embeddings
+//
+//                        } else {
+//                            System.err.println("Top K embeddings are null");
+//                        }
+//
+//                        // If there is more information you want to process, handle it here
+//
+//                    } else {
+//                        System.err.println("Response body is null");
+//                    }
+//                } else {
+//                    System.err.println("Request failed. HTTP error code: " + response.code());
+//                    log.error("Failed to retrieve data. HTTP error code: " + response.code());
+//                }
+//            }
+//
+//            @Override
+//            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
+//                // 请求失败,打印错误
+//                t.printStackTrace();
+//                log.error("Request failed: ", t);
+//            }
+//        });
+//    }
+
+}

+ 25 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java

@@ -0,0 +1,25 @@
+package org.ruoyi.common.chat.localModels;
+
+
+
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import retrofit2.Call;
+import retrofit2.http.Body;
+import retrofit2.http.POST;
+/**
+ * @program: RUOYIAI
+ * @ClassName SearchService
+ * @description: 请求模型
+ * @author: hejh
+ * @create: 2025-03-15 17:27
+ * @Version 1.0
+ **/
+
+
+public interface SearchService {
+    @POST("/vectorize") // 与 Flask 服务中的路由匹配
+    Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
+}
+
+

+ 92 - 0
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java

@@ -0,0 +1,92 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+import jakarta.annotation.Resource;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.ruoyi.common.chat.config.ChatConfig;
+import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
+import org.ruoyi.common.chat.openai.OpenAiStreamClient;
+import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
+import org.ruoyi.knowledge.service.IKnowledgeInfoService;
+import org.springframework.stereotype.Component;
+
+import java.util.ArrayList;
+import java.util.List;
+
+@Component
+@Slf4j
+@RequiredArgsConstructor
+public class LocalModelsVectorization   {
+    @Resource
+    private IKnowledgeInfoService knowledgeInfoService;
+
+    @Resource
+    private LocalModelsofitClient localModelsofitClient;
+
+    @Getter
+    private OpenAiStreamClient openAiStreamClient;
+
+    private final ChatConfig chatConfig;
+
+    /**
+     * 批量向量化
+     *
+     * @param chunkList 文本块列表
+     * @param kid 知识 ID
+     * @return 向量化结果
+     */
+
+    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
+        logVectorizationRequest(kid, chunkList);  // 在向量化开始前记录日志
+        openAiStreamClient = chatConfig.getOpenAiStreamClient(); // 获取 OpenAi 客户端
+        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // 查询知识信息
+        // 调用 localModelsofitClient 获取 Top K 嵌入向量
+        try {
+            return localModelsofitClient.getTopKEmbeddings(
+                    chunkList,
+                    knowledgeInfoVo.getVector(),
+                    knowledgeInfoVo.getKnowledgeSeparator(),
+                    knowledgeInfoVo.getRetrieveLimit(),
+                    knowledgeInfoVo.getTextBlockSize(),
+                    knowledgeInfoVo.getOverlapChar()
+            );
+        } catch (Exception e) {
+            log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
+            throw new RuntimeException("Batch vectorization failed", e);
+        }
+    }
+
+    /**
+     * 单一文本块向量化
+     *
+     * @param chunk 单一文本块
+     * @param kid 知识 ID
+     * @return 向量化结果
+     */
+
+    public List<Double> singleVectorization(String chunk, String kid) {
+        List<String> chunkList = new ArrayList<>();
+        chunkList.add(chunk);
+
+        // 调用批量向量化方法
+        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
+
+        if (vectorList.isEmpty()) {
+            log.warn("Vectorization returned empty list for chunk: {}", chunk);
+            return new ArrayList<>();
+        }
+
+        return vectorList.get(0); // 返回第一个向量
+    }
+
+    /**
+     * 提供更简洁的日志记录方法
+     *
+     * @param kid 知识 ID
+     * @param chunkList 文本块列表
+     */
+    private void logVectorizationRequest(String kid, List<String> chunkList) {
+        log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
+    }
+}

+ 51 - 9
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java

@@ -18,6 +18,7 @@ import org.springframework.stereotype.Component;
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
 
 @Component
 @Slf4j
@@ -27,6 +28,9 @@ public class OpenAiVectorization implements Vectorization {
     @Lazy
     @Resource
     private IKnowledgeInfoService knowledgeInfoService;
+    @Lazy
+    @Resource
+    private LocalModelsVectorization localModelsVectorization;
 
     @Getter
     private OpenAiStreamClient openAiStreamClient;
@@ -35,25 +39,63 @@ public class OpenAiVectorization implements Vectorization {
 
     @Override
     public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
-        openAiStreamClient = chatConfig.getOpenAiStreamClient();
+        List<List<Double>> vectorList = new ArrayList<>();
+
+        // 获取知识库信息
         KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
-        Embedding embedding = Embedding.builder()
-            .input(chunkList)
-            .model(knowledgeInfoVo.getVectorModel())
-            .build();
+
+        // 如果使用本地模型
+        try {
+            return localModelsVectorization.batchVectorization(chunkList, kid);
+        } catch (Exception e) {
+            log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
+        }
+
+        // 如果本地模型失败,则调用 OpenAI 服务进行向量化
+        Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
         EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
+
+        // 处理 OpenAI 返回的嵌入数据
+        vectorList = processOpenAiEmbeddings(embeddings);
+
+        return vectorList;
+    }
+
+    /**
+     * 构建 Embedding 对象
+     */
+    private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
+        return Embedding.builder()
+                .input(chunkList)
+                .model(knowledgeInfoVo.getVectorModel())
+                .build();
+    }
+
+    /**
+     * 处理 OpenAI 返回的嵌入数据
+     */
+    private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
         List<List<Double>> vectorList = new ArrayList<>();
+
         embeddings.getData().forEach(data -> {
             List<BigDecimal> vector = data.getEmbedding();
-            List<Double> doubleVector = new ArrayList<>();
-            for (BigDecimal bd : vector) {
-                doubleVector.add(bd.doubleValue());
-            }
+            List<Double> doubleVector = convertToDoubleList(vector);
             vectorList.add(doubleVector);
         });
+
         return vectorList;
     }
 
+    /**
+     * 将 BigDecimal 转换为 Double 列表
+     */
+    private List<Double> convertToDoubleList(List<BigDecimal> vector) {
+        return vector.stream()
+                .map(BigDecimal::doubleValue)
+                .collect(Collectors.toList());
+    }
+
+
     @Override
     public List<Double> singleVectorization(String chunk, String kid) {
         List<String> chunkList = new ArrayList<>();

+ 15 - 0
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java

@@ -0,0 +1,15 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+public enum VectorizationType {
+    OPENAI,    // OpenAI 向量化
+    LOCAL;     // 本地模型向量化
+
+    public static VectorizationType fromString(String type) {
+        for (VectorizationType v : values()) {
+            if (v.name().equalsIgnoreCase(type)) {
+                return v;
+            }
+        }
+        throw new IllegalArgumentException("Unknown VectorizationType: " + type);
+    }
+}

+ 21 - 0
script/docker/localModels/Dockerfile

@@ -0,0 +1,21 @@
+# 使用官方 Python 作为基础镜像
+FROM python:3.8-slim
+
+# 设置工作目录为 /app
+WORKDIR /app
+
+# 复制当前目录下的所有文件到 Docker 容器的 /app 目录
+COPY . /app
+
+# 安装应用依赖
+RUN pip install --no-cache-dir -r requirements.txt
+
+# 暴露 Flask 应用使用的端口
+EXPOSE 5000
+
+# 设置环境变量
+ENV FLASK_APP=app.py
+ENV FLASK_RUN_HOST=0.0.0.0
+
+# 启动 Flask 应用
+CMD ["flask", "run", "--host=0.0.0.0"]

+ 116 - 0
script/docker/localModels/app.py

@@ -0,0 +1,116 @@
+from flask import Flask, request, jsonify
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import json
+
+app = Flask(__name__)
+
+# 创建一个全局的模型缓存字典
+model_cache = {}
+
+# 分割文本块
+def split_text(text, block_size, overlap_chars, delimiter):
+    chunks = text.split(delimiter)
+    text_blocks = []
+    current_block = ""
+
+    for chunk in chunks:
+        if len(current_block) + len(chunk) + 1 <= block_size:
+            if current_block:
+                current_block += " " + chunk
+            else:
+                current_block = chunk
+        else:
+            text_blocks.append(current_block)
+            current_block = chunk
+    if current_block:
+        text_blocks.append(current_block)
+
+    overlap_blocks = []
+    for i in range(len(text_blocks)):
+        if i > 0:
+            overlap_block = text_blocks[i - 1][-overlap_chars:] + text_blocks[i]
+            overlap_blocks.append(overlap_block)
+        overlap_blocks.append(text_blocks[i])
+
+    return overlap_blocks
+
+# 文本向量化
+def vectorize_text_blocks(text_blocks, model):
+    return model.encode(text_blocks)
+
+# 文本检索
+def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
+    # 将知识库拆分为文本块
+    text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
+    # 向量化文本块
+    knowledge_vectors = vectorize_text_blocks(text_blocks, model)
+    # 向量化查询文本
+    query_vector = model.encode([query]).reshape(1, -1)
+    # 计算相似度
+    similarities = cosine_similarity(query_vector, knowledge_vectors)
+    # 获取相似度最高的 k 个文本块的索引
+    top_k_indices = similarities[0].argsort()[-k:][::-1]
+
+    # 返回文本块和它们的向量
+    top_k_texts = [text_blocks[i] for i in top_k_indices]
+    top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]
+
+    return top_k_texts, top_k_embeddings
+
+@app.route('/vectorize', methods=['POST'])
+def vectorize_text():
+    # 从请求中获取 JSON 数据
+    data = request.json
+    print(f"Received request data: {data}")  # 调试输出请求数据
+
+    text_list = data.get("text", [])
+    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # 默认模型
+
+    delimiter = data.get("delimiter", "\n")  # 默认分隔符
+    k = int(data.get("k", 3))  # 默认检索条数
+    block_size = int(data.get("block_size", 500))  # 默认文本块大小
+    overlap_chars = int(data.get("overlap_chars", 50))  # 默认重叠字符数
+
+    if not text_list:
+        return jsonify({"error": "Text is required."}), 400
+
+    # 检查模型是否已经加载
+    if model_name not in model_cache:
+        try:
+            model = SentenceTransformer(model_name)
+            model_cache[model_name] = model  # 缓存模型
+        except Exception as e:
+            return jsonify({"error": f"Failed to load model: {e}"}), 500
+
+    model = model_cache[model_name]
+
+    top_k_texts_all = []
+    top_k_embeddings_all = []
+
+    # 如果只有一个查询文本
+    if len(text_list) == 1:
+        top_k_texts, top_k_embeddings = retrieve_top_k(text_list[0], text_list[0], k, block_size, overlap_chars, delimiter, model)
+        top_k_texts_all.append(top_k_texts)
+        top_k_embeddings_all.append(top_k_embeddings)
+    elif len(text_list) > 1:
+        # 如果多个查询文本,依次处理
+        for query in text_list:
+            top_k_texts, top_k_embeddings = retrieve_top_k(query, text_list[0], k, block_size, overlap_chars, delimiter, model)
+            top_k_texts_all.append(top_k_texts)
+            top_k_embeddings_all.append(top_k_embeddings)
+
+    # 将嵌入向量(ndarray)转换为可序列化的列表
+    top_k_embeddings_all = [[embedding.tolist() for embedding in embeddings] for embeddings in top_k_embeddings_all]
+
+    print(f"Top K texts: {top_k_texts_all}")  # 打印检索到的文本
+    print(f"Top K embeddings: {top_k_embeddings_all}")  # 打印检索到的向量
+
+    # 返回 JSON 格式的数据
+    return jsonify({
+
+        "topKEmbeddings": top_k_embeddings_all  # 返回嵌入向量
+    })
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=5000, debug=True)

+ 3 - 0
script/docker/localModels/requirements.txt

@@ -0,0 +1,3 @@
+Flask==2.0.3
+sentence-transformers==2.2.0
+scikit-learn==0.24.2