Эх сурвалжийг харах

fix: ollama兼容联网查询 知识库检索

ageerle 2 долоо хоног өмнө
parent
commit
efeb0bd6fb

+ 1 - 1
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiStreamClient.java

@@ -466,8 +466,8 @@ public class OpenAiStreamClient {
      * @since 1.1.3
      */
     public ResponseBody textToSpeech(TextToSpeech textToSpeech){
-        Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
         try {
+            Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
             return responseBody.execute().body();
         } catch (IOException e) {
             throw new BaseException("文本转语音(同步)失败: "+e.getMessage());

+ 5 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/request/ChatRequest.java

@@ -26,6 +26,11 @@ public class ChatRequest {
      */
     private String prompt;
 
+    /**
+     * 系统提示词
+     */
+    private String sysPrompt;
+
     /**
      * 是否开启流式对话
      */

+ 90 - 60
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java

@@ -84,77 +84,43 @@ public class SseServiceImpl implements ISseService {
 
     private final IChatCostService chatCostService;
 
-    private static final String requestIdTemplate = "mycompany-%d";
+    private static final String requestIdTemplate = "company-%d";
 
     private static final ObjectMapper mapper = new ObjectMapper();
 
-    private OpenAiStreamClient openAiModelStreamClient;
-
     @Override
     public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
         SseEmitter sseEmitter = new SseEmitter(0L);
         SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
         // 获取对话消息列表
         List<Message> messages = chatRequest.getMessages();
-        // 用户对话内容
-        String chatString = null;
         try {
-            if (StpUtil.isLogin()) {
-                // 通过模型名称查询模型信息
-                ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
-                if(chatModelVo!=null){
-                    // 通过模型信息构建请求客户端
-                    openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
-                }else {
-                    // 使用默认客户端
-                    openAiModelStreamClient  = openAiStreamClient;
-                }
+            // 查询模型信息
+            ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
+
+            OpenAiStreamClient openAiModelStreamClient;
+            if(chatModelVo!=null){
+                // Build the request client from the model's API host and API key
+                openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
                 // 设置默认提示词
-                Message sysMessage = Message.builder().content(chatModelVo.getSystemPrompt()).role(Message.Role.SYSTEM).build();
-                messages.add(0,sysMessage);
-
-                // 查询向量库相关信息加入到上下文
-                if(chatRequest.getKid()!=null){
-                    List<Message> knMessages = new ArrayList<>();
-                    String content = messages.get(messages.size() - 1).getContent().toString();
-                    List<String> nearestList;
-                    List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
-                    nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
-                    for (String prompt : nearestList) {
-                        Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
-                        knMessages.add(userMessage);
-                    }
-                    Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
-                    knMessages.add(userMessage);
-                    messages.addAll(knMessages);
-                }
+                chatRequest.setSysPrompt(chatModelVo.getSystemPrompt());
+            }else {
+                // 使用默认客户端
+                openAiModelStreamClient = openAiStreamClient;
+            }
+            // 构建消息列表增加联网、知识库等内容
+            buildChatMessageList(chatRequest);
 
-                // 获取用户对话信息
-                Object content = messages.get(messages.size() - 1).getContent();
-                if (content instanceof List<?> listContent) {
-                    if (CollectionUtil.isNotEmpty(listContent)) {
-                        chatString = listContent.get(0).toString();
-                    }
-                } else if (content instanceof String) {
-                    chatString = (String) content;
-                }
+            // 根据模型名称前缀调用不同的处理逻辑
+            switchModelAndHandle(chatRequest);
 
-                // 加载联网信息
-                if(chatRequest.getSearch()){
-                    Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
-                    messages.add(message);
-                }
-            }else {
-                // 未登录用户限制对话次数
+            // 未登录用户限制对话次数
+            if (!StpUtil.isLogin()) {
                 String clientIp = IpUtil.getClientIp(request);
-
                 // 访客每天默认只能对话5次
                 int timeWindowInSeconds = 5;
-
                 String redisKey = "clientIp:" + clientIp;
-
                 int count = 0;
-
                 if (RedisUtils.getCacheObject(redisKey) == null) {
                     // 缓存有效时间1天
                     RedisUtils.setCacheObject(redisKey, count, Duration.ofSeconds(86400));
@@ -175,6 +141,7 @@ public class SseServiceImpl implements ISseService {
                     .stream(chatRequest.getStream())
                     .build();
             openAiModelStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
+
             // 保存消息记录 并扣除费用
             chatCostService.deductToken(chatRequest);
         } catch (Exception e) {
@@ -185,6 +152,69 @@ public class SseServiceImpl implements ISseService {
         return sseEmitter;
     }
 
+    /**
+     *  Dispatches on the model-name prefix: for "ollama-" the prefix is stripped from the request's model and ollamaChat is called (the SseEmitter it returns is discarded here); names starting with "gpt-4-gizmo" are normalized to exactly "gpt-4-gizmo"; other names are left unchanged.
+     */
+    private void switchModelAndHandle(ChatRequest chatRequest) {
+        String model = chatRequest.getModel();
+        // Model names starting with "ollama-" are routed to ollamaChat; NOTE(review): presumably a locally deployed Ollama model - confirm
+        if (model.startsWith("ollama-")) {
+            String[] parts = chatRequest.getModel().split("ollama-", 2); // limit 2: at most [text before the prefix, remainder]
+            if (parts.length > 1) {
+                chatRequest.setModel(parts[1]);
+                ollamaChat(chatRequest);
+            } else {
+                throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
+            }
+        } else if (model.startsWith("gpt-4-gizmo")) {
+            chatRequest.setModel("gpt-4-gizmo");
+        }
+    }
+
+    /**
+     *  Builds the request's message list in place: prepends a system message, appends nearest vector-store entries when kid is set, stores the last message's text as the request prompt, and appends web-search output when search is on.
+     */
+    private void buildChatMessageList(ChatRequest chatRequest){
+        // Conversation messages taken from the request; this list is mutated directly below
+        List<Message> messages = chatRequest.getMessages();
+        // Insert a SYSTEM message at index 0 whose content is whatever getSysPrompt() returns
+        Message sysMessage = Message.builder().content(chatRequest.getSysPrompt()).role(Message.Role.SYSTEM).build();
+        messages.add(0,sysMessage);
+
+        // When a kid is given, look up the nearest vector-store entries for the last message and add them to the context
+        if(chatRequest.getKid()!=null){
+            List<Message> knMessages = new ArrayList<>();
+            String content = messages.get(messages.size() - 1).getContent().toString();
+            List<String> nearestList;
+            List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
+            nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
+            for (String prompt : nearestList) {
+                Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
+                knMessages.add(userMessage);
+            }
+            Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
+            knMessages.add(userMessage);
+            messages.addAll(knMessages);
+        }
+        // Text of the last message; stays null when its content is neither a String nor a non-empty List
+        String chatString = null;
+        // Read the content of the last message in the list (after any knowledge-base messages were appended)
+        Object content = messages.get(messages.size() - 1).getContent();
+        if (content instanceof List<?> listContent) {
+            if (CollectionUtil.isNotEmpty(listContent)) {
+                chatString = listContent.get(0).toString();
+            }
+        } else if (content instanceof String) {
+            chatString = (String) content;
+        }
+        // Store the extracted text as the request prompt
+        chatRequest.setPrompt(chatString);
+        // When search is enabled, append the webSearch result for that text as an ASSISTANT message
+        if(chatRequest.getSearch()){
+            Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
+            messages.add(message);
+        }
+    }
 
     /**
      * 发送SSE错误事件的封装方法
@@ -295,13 +325,13 @@ public class SseServiceImpl implements ISseService {
 
     @Override
     public SseEmitter ollamaChat(ChatRequest chatRequest) {
-        String[] parts = chatRequest.getModel().split("ollama-");
+
         ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
         final SseEmitter emitter = new SseEmitter();
         String host = chatModelVo.getApiHost();
         List<Message> msgList = chatRequest.getMessages();
-        List<OllamaChatMessage> messages = new ArrayList<>();
 
+        List<OllamaChatMessage> messages = new ArrayList<>();
         for (Message message : msgList) {
             OllamaChatMessage ollamaChatMessage = new OllamaChatMessage();
             ollamaChatMessage.setRole(OllamaChatMessageRole.USER);
@@ -310,7 +340,7 @@ public class SseServiceImpl implements ISseService {
         }
         OllamaAPI api = new OllamaAPI(host);
         api.setRequestTimeoutSeconds(100);
-        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(parts[1]);
+        OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatRequest.getModel());
 
         OllamaChatRequestModel requestModel = builder
             .withMessages(messages)
@@ -356,11 +386,11 @@ public class SseServiceImpl implements ISseService {
 
     @Override
     public String webSearch (String prompt) {
-        String zhipuValue = configService.getConfigValue("zhipu", "key");
-        if(StringUtils.isEmpty(zhipuValue)){
-            throw new IllegalStateException("zhipu config value is empty,请在chat_config中配置zhipu key信息");
+        String zpValue = configService.getConfigValue("zhipu", "key");
+        if(StringUtils.isEmpty(zpValue)){
+            throw new IllegalStateException("请在chat_config中配置智谱key信息");
         }else {
-            ClientV4 client = new ClientV4.Builder(zhipuValue)
+            ClientV4 client = new ClientV4.Builder(zpValue)
                     .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
                     .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
                     .build();