Преглед на файлове

大模型提示词和SQL格式化

Gaokun Wang преди 4 месеца
родител
ревизия
002d641cfc

+ 7 - 27
eco-ai/ai-text-sql-biz/src/main/java/org/eco/vip/ai/text2sql/service/Text2SqlService.java

@@ -15,7 +15,7 @@ import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.eco.vip.ai.text2sql.domain.CommentVo;
 import org.eco.vip.ai.text2sql.mapper.Text2SqlMapper;
-import org.eco.vip.ai.text2sql.utils.SqlValidator;
+import org.eco.vip.ai.text2sql.utils.SqlUtils;
 import org.eco.vip.orm.utils.DBaseHelper;
 import org.eco.vip.text2sql.domain.ContentVo;
 import org.springframework.stereotype.Service;
@@ -89,8 +89,8 @@ public class Text2SqlService implements IText2SqlService {
     @Transactional(rollbackFor = Exception.class)
     public JSONArray getAnswer(ContentVo question) {
         log.info("\n模型返回的SQL:\n{}", question.getContent());
-        String sql = cleanRawSql(question.getContent());
-        if (!SqlValidator.isSelectQuery(sql)) {
+        String sql = SqlUtils.extractSql(question.getContent());
+        if (!SqlUtils.isSelectQuery(sql)) {
             return null;
         }
         List<Map<String, Object>> mapList;
@@ -113,6 +113,9 @@ public class Text2SqlService implements IText2SqlService {
     @Transactional(rollbackFor = Exception.class)
     public ContentVo getQuestion(String question) {
         String tableName = SpringUtil.getBean(Text2SqlService.class).tableNameByNl(question);
+        if (StrUtil.isEmpty(tableName)) {
+            return null;
+        }
         Map<String, String> map = Collections.emptyMap();
         if (DBaseHelper.isMySql()) {
             map = text2SqlMapper.selectTableDdlByMysql(tableName);
@@ -123,35 +126,12 @@ public class Text2SqlService implements IText2SqlService {
         question = "你是一个专业的SQL生成助手,请严格按照以下提供的表DDL生成查询语句。\nDLL:\n" + ddl + "\n条件:\n" + question + ",根据需求生成" + DBaseHelper.getDbType() + "的查询SQL.\n" +
                 "1.解析用户提供的表DDL,提取所有字段名称;\n" +
                 "2.生成的查询语句必须仅包含这些字段,不可添加其他字段;\n" +
-                "3. 如果用户未明确指定查询逻辑,默认生成基础SELECT语句.\n" +
+                "3.如果用户未明确指定查询逻辑,默认生成基础SELECT语句.输出 SELECT * FROM " + tableName + "\n" +
                 "4.不需要带Schema:" + DBaseHelper.getSchema() + "\n" +
                 "5.只输出SQL";
-//        question= ddl + "\n" + question + ",根据需求生成" + DBaseHelper.getDbType() + "的查询SQL.\n" +
-//                "1.生成的查询语句必须仅包含这些字段,不可添加其他字段,from 后面的条件不要自己发挥\n" +
-//                "2.只输出sql语句不需要任何格式样式,就是一串sql.\n" +
-//                "3.不需要带Schema:" + DBaseHelper.getSchema();
         log.info("\nquestion ____________________:\n{}", question);
         return ContentVo.builder().content(question).
                 tableName(tableName).build();
 
     }
-
-    private String cleanRawSql(String rawSql) {
-//        // 1. 去除代码块标记(如 ```sql)和语言标识
-        String cleaned = rawSql.replaceAll("(?i).*```sql\\s*", "")
-                .replaceAll("\\s*```.*", "")
-                .trim();
-//        String cleaned = rawSql.trim();
-//        if (rawSql.contains("```")) {
-//            cleaned = rawSql.replaceAll("(?i).*```sql\\s*", "")
-//                    .replaceAll("\\s*```.*", "")
-//                    .trim();
-//            Pattern pattern = Pattern.compile("```([\\s\\S]*?)```");
-//            Matcher matcher = pattern.matcher(cleaned);
-//            cleaned = matcher.group(1).trim();
-//        }
-        // 2. 去除末尾分号(MySQL允许不带分号执行)
-        cleaned = cleaned.replaceAll(";\\s*$", "");
-        return cleaned;
-    }
 }

+ 57 - 0
eco-ai/ai-text-sql-biz/src/main/java/org/eco/vip/ai/text2sql/utils/SqlUtils.java

@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2025 GaoKunW
+ *
+ */
+
+package org.eco.vip.ai.text2sql.utils;
+
+
+import lombok.extern.slf4j.Slf4j;
+import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.statement.Statement;
+import net.sf.jsqlparser.statement.select.Select;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * @author GaoKunW
+ * @description SqlValidatorAdvanced
+ * @date 2025/3/13 10:34
+ */
+@Slf4j
+public class SqlUtils {
+    public static boolean isSelectQuery(String sql) {
+        try {
+            Statement statement = CCJSqlParserUtil.parse(sql);
+            return statement instanceof Select;
+        } catch (JSQLParserException e) {
+            log.error("SQL解析异常:{}", e.getMessage());
+            return false;
+        }
+    }
+
+    public static String extractSql(String rawSql) {
+        // 匹配带反引号的代码块(支持可选的sql标识符)
+        Pattern codeBlockPattern = Pattern.compile("(?s)```(?:sql)?\\s*(.*?)\\s*```");
+        // 直接匹配SELECT语句(跨行、不区分大小写)
+        Pattern sqlPattern = Pattern.compile("(?si)(SELECT.*?;)");
+        String sql = null;
+        Matcher codeMatcher = codeBlockPattern.matcher(rawSql);
+        // 优先提取代码块内容
+        if (codeMatcher.find()) {
+            sql = codeMatcher.group(1).trim();
+        } else { // 若无代码块则直接匹配SQL结构
+            Matcher sqlMatcher = sqlPattern.matcher(rawSql);
+            if (sqlMatcher.find()) {
+                sql = sqlMatcher.group(1).trim();
+            }
+        }
+        if (sql != null) {
+            sql = sql.replaceAll(";\\s*$", "");
+        }
+        log.info("\n-----Formatted SQL: {}\n-----", sql);
+        return sql;
+    }
+}

+ 0 - 31
eco-ai/ai-text-sql-biz/src/main/java/org/eco/vip/ai/text2sql/utils/SqlValidator.java

@@ -1,31 +0,0 @@
-/*
- * Copyright (c) 2025 GaoKunW
- *
- */
-
-package org.eco.vip.ai.text2sql.utils;
-
-
-import lombok.extern.slf4j.Slf4j;
-import net.sf.jsqlparser.JSQLParserException;
-import net.sf.jsqlparser.parser.CCJSqlParserUtil;
-import net.sf.jsqlparser.statement.Statement;
-import net.sf.jsqlparser.statement.select.Select;
-
-/**
- * @author GaoKunW
- * @description SqlValidatorAdvanced
- * @date 2025/3/13 10:34
- */
-@Slf4j
-public class SqlValidator {
-    public static boolean isSelectQuery(String sql) {
-        try {
-            Statement statement = CCJSqlParserUtil.parse(sql);
-            return statement instanceof Select;
-        } catch (JSQLParserException e) {
-            log.error("SQL解析异常:{}", e.getMessage());
-            return false;
-        }
-    }
-}

+ 43 - 0
eco-start/src/main/resources/application-dev.yml

@@ -0,0 +1,43 @@
+--- # 数据源配置
+spring:
+  datasource:
+    type: com.zaxxer.hikari.HikariDataSource
+mybatis-flex:
+  # sql审计
+  audit_enable: true
+  # sql打印
+  sql_print: true
+  # 数据源
+  datasource:
+    # 数据源1
+    ds1:
+      type: ${spring.datasource.type}
+      # MySql
+      #      driver-class-name: com.mysql.cj.jdbc.Driver
+      #      url: jdbc:mysql://localhost:3306/eco-boot?useUnicode=true&characterEncoding=utf8&zeroDateTimeBehavior=convertToNull&useSSL=true&serverTimezone=GMT%2B8&autoReconnect=true&rewriteBatchedStatements=true&allowPublicKeyRetrieval=true
+      #      username: root
+      #      password: root123
+      #DM8数据库
+      driver-class-name: dm.jdbc.driver.DmDriver
+      url: jdbc:dm://127.0.0.1:5236?schema=lqbz&useUnicode=true&characterEncoding=utf8&useSSL=true&autoReconnect=true&reWriteBatchedInserts=true
+      username: SYSDBA
+      password: SYSDBA123
+      # 最大连接池数量
+      maximumPoolSize: 50
+      # 最小空闲线程数量
+      minimumIdle: 10
+      # 配置获取连接等待超时的时间
+      connectionTimeout: 60000
+      # 校验超时时间
+      validationTimeout: 5000
+      # 空闲连接存活最大时间,默认10分钟
+      idleTimeout: 600000
+      # 此属性控制池中连接的最长生命周期,值0表示无限生命周期,默认30分钟
+      maxLifetime: 1800000
+      # 多久检查一次连接的活性
+      keepaliveTime: 30000
+    # 数据源2
+#    ds2:
+#      url: jdbc:mysql://127.0.0.1:3306/eco1
+#      username: root
+#      password: root123

+ 3 - 3
eco-start/src/main/resources/application-local.yml

@@ -21,11 +21,11 @@ mybatis-flex:
       driver-class-name: dm.jdbc.driver.DmDriver
       url: jdbc:dm://127.0.0.1:5236?schema=lqbz&useUnicode=true&characterEncoding=utf8&useSSL=true&autoReconnect=true&reWriteBatchedInserts=true
       username: SYSDBA
-      password: SYSdba123
+      password: SYSDBA123
       # 最大连接池数量
-      maximum-pool-size: 50
+      maximumPoolSize: 50
       # 最小空闲线程数量
-      minimum-idle: 10
+      minimumIdle: 10
       # 配置获取连接等待超时的时间
       connectionTimeout: 60000
       # 校验超时时间

+ 1 - 1
eco-start/src/main/resources/logback-spring.xml

@@ -7,7 +7,7 @@
     <include resource="org/springframework/boot/logging/logback/defaults.xml"/>
 
     <!--定义日志存放的位置,默认存放在项目启动的相对路径的目录-->
-    <springProperty scope="context" name="LOG_PATH" source="log.path" defaultValue="./logs"/>
+    <springProperty scope="context" name="LOG_PATH" source="log.path" defaultValue="logs"/>
 
     <!-- ****************************** 本地开发只在控制台打印日志 ************************************ -->
     <springProfile name="local">