|
@@ -1,122 +0,0 @@
|
|
|
-/*
|
|
|
- * Copyright (c) 2025 GaoKunW
|
|
|
- *
|
|
|
- */
|
|
|
-
|
|
|
-package org.eco.vip.ai.text2sql.service;
|
|
|
-
|
|
|
-
|
|
|
-import cn.hutool.core.collection.CollUtil;
|
|
|
-import cn.hutool.core.util.ObjUtil;
|
|
|
-import cn.hutool.core.util.StrUtil;
|
|
|
-import cn.hutool.json.JSONArray;
|
|
|
-import jakarta.annotation.Resource;
|
|
|
-import lombok.extern.slf4j.Slf4j;
|
|
|
-import org.eco.vip.ai.text2sql.domain.CommentVo;
|
|
|
-import org.eco.vip.ai.text2sql.mapper.Text2SqlMapper;
|
|
|
-import org.eco.vip.ai.text2sql.utils.SqlValidator;
|
|
|
-import org.eco.vip.orm.utils.DBaseHelper;
|
|
|
-import org.eco.vip.text2sql.domain.ContentVo;
|
|
|
-import org.springframework.stereotype.Service;
|
|
|
-
|
|
|
-import java.util.Collections;
|
|
|
-import java.util.LinkedHashMap;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
-
|
|
|
-/**
|
|
|
- * @author GaoKunW
|
|
|
- * @description Text2SqlService
|
|
|
- * @date 2025/3/12 10:54
|
|
|
- */
|
|
|
-@Slf4j
|
|
|
-@Service
|
|
|
-public class Text2SqlService implements IText2SqlService {
|
|
|
- @Resource
|
|
|
- private Text2SqlMapper text2SqlMapper;
|
|
|
-
|
|
|
- @Resource
|
|
|
- private TableNameService tableNameService;
|
|
|
-
|
|
|
- @Override
|
|
|
- public List<Map<String, Object>> executeSql(String sql) {
|
|
|
- return text2SqlMapper.executeSql(sql);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String tableNameByNl(String question) {
|
|
|
- return tableNameService.extractTableName(question);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public List<CommentVo> commentByTableName(String schema, String tableName) {
|
|
|
- if (DBaseHelper.isMySql()) {
|
|
|
- return text2SqlMapper.selectCommentsByMysql(schema, tableName);
|
|
|
- } else if (DBaseHelper.isDmSql()) {
|
|
|
- return text2SqlMapper.selectCommentsByDm(schema, tableName);
|
|
|
- }
|
|
|
- return Collections.emptyList();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public JSONArray getCommentData(List<Map<String, Object>> mapList, List<CommentVo> tableNames) {
|
|
|
- JSONArray jsonArray = new JSONArray();
|
|
|
- Map<String, Object> commentMap = new LinkedHashMap<>();
|
|
|
- for (Map<String, Object> map : mapList) {
|
|
|
- for (String key : map.keySet()) {
|
|
|
- if (StrUtil.contains("tenantId,delFlag,status,createBy,updateBy,createTime,updateTime", key)) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- CommentVo commentVo = CollUtil.findOne(tableNames, tableName -> StrUtil.equals(StrUtil.toCamelCase(tableName.getName()), key));
|
|
|
- if (ObjUtil.isNotEmpty(commentVo)) {
|
|
|
- commentMap.put(commentVo.getComment(), map.get(key));
|
|
|
- } else {
|
|
|
- commentMap.put(key, map.get(key));
|
|
|
- }
|
|
|
- }
|
|
|
- jsonArray.add(commentMap);
|
|
|
- }
|
|
|
- return jsonArray;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public JSONArray getAnswer(ContentVo question) {
|
|
|
- log.info("\n模型返回的SQL:\n{}", question.getContent());
|
|
|
- String sql = cleanRawSql(question.getContent());
|
|
|
- if (!SqlValidator.isSelectQuery(sql)) {
|
|
|
- return new JSONArray().put("只支持查询");
|
|
|
- }
|
|
|
- List<Map<String, Object>> mapList = text2SqlMapper.executeSql(sql);
|
|
|
- List<CommentVo> tableNames = this.commentByTableName(DBaseHelper.getSchema(), question.getTableName());
|
|
|
- return this.getCommentData(mapList, tableNames);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public ContentVo getQuestion(String question) {
|
|
|
- String tableName = this.tableNameByNl(question);
|
|
|
- Map<String, String> map = Collections.emptyMap();
|
|
|
- if (DBaseHelper.isMySql()) {
|
|
|
- map = text2SqlMapper.selectTableDdlByMysql(tableName);
|
|
|
- } else if (DBaseHelper.isDmSql()) {
|
|
|
- map = text2SqlMapper.selectTableDdlDm(DBaseHelper.getSchema(), tableName);
|
|
|
- }
|
|
|
- String ddl = map.get("create table");
|
|
|
- return ContentVo.builder().content(ddl + "\n" + question + ",根据需求生成" + DBaseHelper.getDbType() + "的查询SQL.\n" +
|
|
|
- "1.不要需求外的条件.\n" +
|
|
|
- "2.只输出sql语句不需要任何格式样式,就是一串sql.\n" +
|
|
|
- "3.不需要带Schema:" + DBaseHelper.getSchema()).
|
|
|
- tableName(tableName).build();
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- private String cleanRawSql(String rawSql) {
|
|
|
- // 1. 去除代码块标记(如 ```sql)和语言标识
|
|
|
- String cleaned = rawSql.replaceAll("(?i).*```sql\\s*", "")
|
|
|
- .replaceAll("\\s*```.*", "")
|
|
|
- .trim();
|
|
|
-
|
|
|
- // 2. 去除末尾分号(MySQL允许不带分号执行)
|
|
|
- cleaned = cleaned.replaceAll(";\\s*$", "");
|
|
|
- return cleaned;
|
|
|
- }
|
|
|
-}
|