|
@@ -9,6 +9,7 @@ package org.eco.vip.ai.text2sql.service;
|
|
|
import cn.hutool.core.collection.CollUtil;
|
|
|
import cn.hutool.core.util.ObjUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
+import cn.hutool.extra.spring.SpringUtil;
|
|
|
import cn.hutool.json.JSONArray;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
@@ -18,6 +19,7 @@ import org.eco.vip.ai.text2sql.utils.SqlValidator;
|
|
|
import org.eco.vip.orm.utils.DBaseHelper;
|
|
|
import org.eco.vip.text2sql.domain.ContentVo;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
+import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
|
import java.util.Collections;
|
|
|
import java.util.LinkedHashMap;
|
|
@@ -39,16 +41,19 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
private TableNameService tableNameService;
|
|
|
|
|
|
@Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
public List<Map<String, Object>> executeSql(String sql) {
|
|
|
return text2SqlMapper.executeSql(sql);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
public String tableNameByNl(String question) {
|
|
|
return tableNameService.extractTableName(question);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
public List<CommentVo> commentByTableName(String schema, String tableName) {
|
|
|
if (DBaseHelper.isMySql()) {
|
|
|
return text2SqlMapper.selectCommentsByMysql(schema, tableName);
|
|
@@ -59,6 +64,7 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
public JSONArray getCommentData(List<Map<String, Object>> mapList, List<CommentVo> tableNames) {
|
|
|
JSONArray jsonArray = new JSONArray();
|
|
|
Map<String, Object> commentMap = new LinkedHashMap<>();
|
|
@@ -80,20 +86,22 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
public JSONArray getAnswer(ContentVo question) {
|
|
|
log.info("\n模型返回的SQL:\n{}", question.getContent());
|
|
|
String sql = cleanRawSql(question.getContent());
|
|
|
if (!SqlValidator.isSelectQuery(sql)) {
|
|
|
return new JSONArray().put("只支持查询");
|
|
|
}
|
|
|
- List<Map<String, Object>> mapList = text2SqlMapper.executeSql(sql);
|
|
|
- List<CommentVo> tableNames = this.commentByTableName(DBaseHelper.getSchema(), question.getTableName());
|
|
|
- return this.getCommentData(mapList, tableNames);
|
|
|
+ List<Map<String, Object>> mapList = SpringUtil.getBean(Text2SqlService.class).executeSql(sql);
|
|
|
+ List<CommentVo> tableNames = SpringUtil.getBean(Text2SqlService.class).commentByTableName(DBaseHelper.getSchema(), question.getTableName());
|
|
|
+ return SpringUtil.getBean(Text2SqlService.class).getCommentData(mapList, tableNames);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
public ContentVo getQuestion(String question) {
|
|
|
- String tableName = this.tableNameByNl(question);
|
|
|
+ String tableName = SpringUtil.getBean(Text2SqlService.class).tableNameByNl(question);
|
|
|
Map<String, String> map = Collections.emptyMap();
|
|
|
if (DBaseHelper.isMySql()) {
|
|
|
map = text2SqlMapper.selectTableDdlByMysql(tableName);
|
|
@@ -101,10 +109,13 @@ public class Text2SqlService implements IText2SqlService {
|
|
|
map = text2SqlMapper.selectTableDdlDm(DBaseHelper.getSchema(), tableName);
|
|
|
}
|
|
|
String ddl = map.get("create table");
|
|
|
- return ContentVo.builder().content(ddl + "\n" + question + ",根据需求生成" + DBaseHelper.getDbType() + "的查询SQL.\n" +
|
|
|
- "1.不要需求外的条件.\n" +
|
|
|
- "2.只输出sql语句不需要任何格式样式,就是一串sql.\n" +
|
|
|
- "3.不需要带Schema:" + DBaseHelper.getSchema()).
|
|
|
+ question= ddl + "\n" + question + ",根据需求生成" + DBaseHelper.getDbType() + "的查询SQL.\n" +
|
|
|
+ "1.不要需求外的条件.\n" +
|
|
|
+ "2.只输出sql语句不需要任何格式样式,就是一串sql.\n" +
|
|
|
+ "3.字段严格按照ddl表结构里面的给出.\n" +
|
|
|
+ "4.不需要带Schema:" + DBaseHelper.getSchema();
|
|
|
+ log.info("\nquestion ____________________:\n{}", question);
|
|
|
+ return ContentVo.builder().content(question).
|
|
|
tableName(tableName).build();
|
|
|
|
|
|
}
|