|
@@ -10,13 +10,17 @@ import cn.hutool.core.collection.CollUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.hutool.json.JSONArray;
|
|
|
import jakarta.annotation.Resource;
|
|
|
+import org.eco.vip.ai.ollama.api.api.IOllamaApi;
|
|
|
import org.eco.vip.ai.text2sql.domain.CommentVo;
|
|
|
import org.eco.vip.ai.text2sql.mapper.Text2SqlMapper;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
|
import java.util.HashMap;
|
|
|
+import java.util.LinkedHashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
+import java.util.regex.Matcher;
|
|
|
+import java.util.regex.Pattern;
|
|
|
|
|
|
/**
|
|
|
* @description Text2SqlService
|
|
@@ -29,14 +33,17 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
@Resource
|
|
|
private Text2SqlMapper text2SqlMapper;
|
|
|
|
|
|
+ @Resource
|
|
|
+ private IOllamaApi ollamaApi;
|
|
|
+
|
|
|
@Override
|
|
|
public List<Map<String, Object>> executeSql(String sql) {
|
|
|
return text2SqlMapper.executeSql(sql);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public List<Map<String, Object>> tableNameByNl(String sql) {
|
|
|
- return text2SqlMapper.executeSql(sql);
|
|
|
+ public String tableNameByNl(String question) {
|
|
|
+ return this.extractTableName(question);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -45,9 +52,7 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public JSONArray getCommentData(String question) {
|
|
|
- List<Map<String, Object>> mapList = text2SqlMapper.executeSql("select * from sys_user");
|
|
|
- List<CommentVo> tableNames = text2SqlMapper.selectCommentsByMysql("eco-boot", "sys_user");
|
|
|
+ public JSONArray getCommentData(List<Map<String, Object>> mapList, List<CommentVo> tableNames) {
|
|
|
JSONArray jsonArray = new JSONArray();
|
|
|
Map<String, Object> commentMap = new HashMap<>();
|
|
|
for (Map<String, Object> map : mapList) {
|
|
@@ -59,4 +64,35 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
}
|
|
|
return jsonArray;
|
|
|
}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public JSONArray getAnswer(String question) {
|
|
|
+ String tableName = this.tableNameByNl(question);
|
|
|
+ Map<String, String> map = text2SqlMapper.selectTableDdl(tableName);
|
|
|
+ String ddl = map.get("create table");
|
|
|
+ String content = ollamaApi.syncChat(ddl + "\n" + question + ",根据需求生成Mysql的查询SQL,不要需求外的条件,只输出sql语句不需要任何格式样式,就是一串sql");
|
|
|
+ List<Map<String, Object>> mapList = text2SqlMapper.executeSql(content);
|
|
|
+ List<CommentVo> tableNames = text2SqlMapper.selectCommentsByMysql("eco-boot", tableName);
|
|
|
+ return this.getCommentData(mapList, tableNames);
|
|
|
+ }
|
|
|
+
|
|
|
+ private String extractTableName(String query) {
|
|
|
+ // 使用LinkedHashMap保持插入顺序,确保匹配优先级
|
|
|
+ Map<String, String> keywordMapping = new LinkedHashMap<>();
|
|
|
+ keywordMapping.put("部门表|部门信息", "sys_dept");
|
|
|
+ keywordMapping.put("用户|人员|姓名|账号|性别", "sys_user");
|
|
|
+
|
|
|
+ // 预处理:去除标点符号和非文字字符
|
|
|
+ String cleanedQuery = query.replaceAll("[^\\w\\u4e00-\\u9fff]", "");
|
|
|
+
|
|
|
+ // 遍历映射进行正则匹配
|
|
|
+ for (Map.Entry<String, String> entry : keywordMapping.entrySet()) {
|
|
|
+ Pattern pattern = Pattern.compile(entry.getKey());
|
|
|
+ Matcher matcher = pattern.matcher(cleanedQuery);
|
|
|
+ if (matcher.find()) {
|
|
|
+ return entry.getValue();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return "unknown_table";
|
|
|
+ }
|
|
|
}
|