|
@@ -15,7 +15,7 @@ import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.eco.vip.ai.text2sql.domain.CommentVo;
|
|
|
import org.eco.vip.ai.text2sql.mapper.Text2SqlMapper;
|
|
|
-import org.eco.vip.ai.text2sql.utils.SqlValidator;
|
|
|
+import org.eco.vip.ai.text2sql.utils.SqlUtils;
|
|
|
import org.eco.vip.orm.utils.DBaseHelper;
|
|
|
import org.eco.vip.text2sql.domain.ContentVo;
|
|
|
import org.springframework.stereotype.Service;
|
|
@@ -89,8 +89,8 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
public JSONArray getAnswer(ContentVo question) {
|
|
|
log.info("\n模型返回的SQL:\n{}", question.getContent());
|
|
|
- String sql = cleanRawSql(question.getContent());
|
|
|
- if (!SqlValidator.isSelectQuery(sql)) {
|
|
|
+ String sql = SqlUtils.extractSql(question.getContent());
|
|
|
+ if (!SqlUtils.isSelectQuery(sql)) {
|
|
|
return null;
|
|
|
}
|
|
|
List<Map<String, Object>> mapList;
|
|
@@ -113,6 +113,9 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
public ContentVo getQuestion(String question) {
|
|
|
String tableName = SpringUtil.getBean(Text2SqlService.class).tableNameByNl(question);
|
|
|
+ if (StrUtil.isEmpty(tableName)) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
Map<String, String> map = Collections.emptyMap();
|
|
|
if (DBaseHelper.isMySql()) {
|
|
|
map = text2SqlMapper.selectTableDdlByMysql(tableName);
|
|
@@ -123,35 +126,12 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
question = "你是一个专业的SQL生成助手,请严格按照以下提供的表DDL生成查询语句。\nDLL:\n" + ddl + "\n条件:\n" + question + ",根据需求生成" + DBaseHelper.getDbType() + "的查询SQL.\n" +
|
|
|
"1.解析用户提供的表DDL,提取所有字段名称;\n" +
|
|
|
"2.生成的查询语句必须仅包含这些字段,不可添加其他字段;\n" +
|
|
|
- "3. 如果用户未明确指定查询逻辑,默认生成基础SELECT语句.\n" +
|
|
|
+ "3.如果用户未明确指定查询逻辑,默认生成基础SELECT语句.输出 SELECT * FROM " + tableName + "\n" +
|
|
|
"4.不需要带Schema:" + DBaseHelper.getSchema() + "\n" +
|
|
|
"5.只输出SQL";
|
|
|
-// question= ddl + "\n" + question + ",根据需求生成" + DBaseHelper.getDbType() + "的查询SQL.\n" +
|
|
|
-// "1.生成的查询语句必须仅包含这些字段,不可添加其他字段,from 后面的条件不要自己发挥\n" +
|
|
|
-// "2.只输出sql语句不需要任何格式样式,就是一串sql.\n" +
|
|
|
-// "3.不需要带Schema:" + DBaseHelper.getSchema();
|
|
|
log.info("\nquestion ____________________:\n{}", question);
|
|
|
return ContentVo.builder().content(question).
|
|
|
tableName(tableName).build();
|
|
|
|
|
|
}
|
|
|
-
|
|
|
- private String cleanRawSql(String rawSql) {
|
|
|
-// // 1. 去除代码块标记(如 ```sql)和语言标识
|
|
|
- String cleaned = rawSql.replaceAll("(?i).*```sql\\s*", "")
|
|
|
- .replaceAll("\\s*```.*", "")
|
|
|
- .trim();
|
|
|
-// String cleaned = rawSql.trim();
|
|
|
-// if (rawSql.contains("```")) {
|
|
|
-// cleaned = rawSql.replaceAll("(?i).*```sql\\s*", "")
|
|
|
-// .replaceAll("\\s*```.*", "")
|
|
|
-// .trim();
|
|
|
-// Pattern pattern = Pattern.compile("```([\\s\\S]*?)```");
|
|
|
-// Matcher matcher = pattern.matcher(cleaned);
|
|
|
-// cleaned = matcher.group(1).trim();
|
|
|
-// }
|
|
|
- // 2. 去除末尾分号(MySQL允许不带分号执行)
|
|
|
- cleaned = cleaned.replaceAll(";\\s*$", "");
|
|
|
- return cleaned;
|
|
|
- }
|
|
|
}
|