Răsfoiți Sursa

Merge branch 'develop' into dev_wk2

WANGKANG 8 luni în urmă
părinte
comite
8e9b5fba41

+ 38 - 62
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataServiceImpl.java

@@ -1,5 +1,7 @@
 package com.taais.biz.service.impl;
 
+import java.util.Date;
+
 import cn.hutool.core.util.ObjectUtil;
 import cn.hutool.http.HttpRequest;
 import com.fasterxml.jackson.databind.JsonNode;
@@ -8,10 +10,11 @@ import com.mybatisflex.core.paginate.Page;
 import com.mybatisflex.core.query.QueryWrapper;
 import com.taais.biz.constant.BizConstant;
 import com.taais.biz.domain.Data;
+import com.taais.biz.domain.DataAmplificationTask;
 import com.taais.biz.domain.bo.DataAmplificationTaskBo;
 import com.taais.biz.domain.bo.DataBo;
-import com.taais.biz.domain.dto.DataAmplifyDto;
 import com.taais.biz.domain.vo.BatchDataResult;
+import com.taais.biz.domain.dto.DataAmplifyDto;
 import com.taais.biz.domain.vo.DataAmplificationTaskVo;
 import com.taais.biz.domain.vo.DataSelectVo;
 import com.taais.biz.domain.vo.DataVo;
@@ -30,6 +33,8 @@ import com.taais.common.json.utils.JsonUtils;
 import com.taais.common.orm.core.page.PageQuery;
 import com.taais.common.orm.core.service.impl.BaseServiceImpl;
 import com.taais.common.redis.utils.RedisUtils;
+import com.taais.system.domain.vo.SysDictDataVo;
+import com.taais.system.service.ISysDictDataService;
 import jakarta.annotation.Resource;
 import net.lingala.zip4j.model.FileHeader;
 import org.redisson.api.RedissonClient;
@@ -70,7 +75,6 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
     private static final String RAR = ".rar";
     private static final String TXT = ".txt";
     private static final String[] VALID_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"};
-    private static final String PYTHON_DATA_AMPLIFY_API = "http://127.0.0.1:11001/augment";
     private static final String RESULT_CODE = "status";
     private static final String RESULT_STATUS = "200";
     private static final String AMPLIFY = "/AMPLIFY/";
@@ -80,6 +84,9 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
     @Resource
     private IDataAmplificationTaskService dataAmplificationTaskService;
 
+    @Resource
+    private ISysDictDataService dictDataService;
+
     @Override
     public QueryWrapper query() {
         return super.query().from(DATA);
@@ -152,6 +159,7 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
     }
 
     @Override
+    @Transactional
     public CommonResult<Boolean> uploadDataInfo(MultipartFile file, Data dataInfo) {
         //1.检测是否有重复的批次号
         QueryWrapper query = query();
@@ -203,6 +211,7 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
             List<Boolean> labeledList = new ArrayList<>();
             List<Data> dataList = new ArrayList<>();
             AtomicInteger countSize = new AtomicInteger();
+            List<Data> finalDataList = dataList;
             extractedImagesFileList.forEach(fileInfo -> {
                 //获取ID
                 Long id = ids.get(countSize.get());
@@ -246,7 +255,7 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
                         String labelUrl = FileUploadUtils.getPathFileName(destInfo, id + ".txt");
                         data.setLabelurl(labelUrl);
                     }
-                    dataList.add(data);
+                    finalDataList.add(data);
                 } catch (IOException e) {
                     throw new RuntimeException(e);
                 }
@@ -270,7 +279,7 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
 //                    return CommonResult.fail(format);
 //                }
 //            }
-            if (labeled){
+            if (labeled) {
                 // 未标注数量
                 long unmarkedCount = labeledList.stream().filter(Boolean.FALSE::equals).count();
                 // 已标注数量
@@ -282,7 +291,17 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
                 }
             }
             // TODO 李兆晏 确认逻辑是否正确 end
-            dataMapper.insertBatch(dataList);
+
+            int batchSize = 100;
+            int totalSize = dataList.size();
+            int page = (int) Math.ceil((double) totalSize / batchSize); // 计算总批次数
+
+            for (int i = 0; i < page; i++) {
+                int fromIndex = i * batchSize;
+                int toIndex = Math.min(fromIndex + batchSize, totalSize); // 确保最后一批不越界
+                List<Data> batchList = new ArrayList<>(dataList.subList(fromIndex, toIndex)); // 生成子列表副本
+                dataMapper.insertBatch(batchList); // 批量插入
+            }
             FileUtils.deleteFile(destZip);
         } catch (Exception e) {
             log.error("[uploadDataInfo]数据集处理出现未知异常.e:", e);
@@ -294,58 +313,15 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
     @Override
     @Transactional
     public CommonResult<Boolean> dataAmplify(DataAmplifyDto dataAmplifyDto) {
-        //根据批次号获取该批次的所有文件数据
-        QueryWrapper query = query();
-        query.eq(Data::getBatchNum, dataAmplifyDto.getBatchNum());
-        List<Data> dataList = dataMapper.selectListByQuery(query);
-        if (dataList.isEmpty()) {
-            return CommonResult.fail("该批次下没有文件数据,请重新选择批次!");
-        }
-        //TODO: 此处需要定义任务开始,把相关任务信息添加上(任务名称、任务开始时间、任务类型),然后再处理文件。
-
-        List<Data> dataListInfo = dataList.stream().filter(data -> !StringUtils.isEmpty(data.getUrl())).toList();
-        if (dataListInfo.isEmpty()) {
-            return CommonResult.fail("该批次下没有文件数据,请重新选择批次!");
-        }
-        String filePath = TaaisConfig.getUploadPath();
-        LocalDate currentDate = LocalDate.now();
-        // 定义日期格式器
-        DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy/MM/dd");
-        String formattedDate = currentDate.format(formatter);
-        filePath = filePath + File.separator + formattedDate;
-        String finalFilePath = filePath;
-        dataListInfo.forEach(dataInfo -> {
-            try {
-                //循环调用Python扩增接口
-                Map<String, Object> bodyJson = new HashMap<>();
-                bodyJson.put("augmentationType", dataAmplifyDto.getAugmentationType());
-                bodyJson.put("inputImagePath", dataInfo.getUrl());
-                String outputImagePath = finalFilePath + AMPLIFY + System.currentTimeMillis();
-                File desc = new File(outputImagePath);
-                if (!desc.exists()) {
-                    log.info("创建文件目录: {}", desc.mkdirs());
-                }
-                bodyJson.put("outputImagePath", outputImagePath);
-                bodyJson.put("otherParams", dataAmplifyDto.getOtherParams());
-                //实际请求接口,接口未提供,暂且注释
-//                String response = HttpRequest.post(PYTHON_DATA_AMPLIFY_API)
-//                        .body(JsonUtils.toJsonString(bodyJson))
-//                        .execute().body();
-                String response = "{\"status\":200,\"msg\":\"扩增成功\"}";
-                ObjectMapper objectMapper = new ObjectMapper();
-                JsonNode rootNode = objectMapper.readTree(response);
-                String resultCode = rootNode.path(RESULT_CODE).asText();
-                //判断接口是否响应成功
-                if (!RESULT_STATUS.equals(resultCode)) {
-                    throw new RuntimeException("调用Python接口返回扩增失败");
-                }
-                //处理当前目录文件,并进行入库
-                saveDataInfo(outputImagePath, dataInfo);
-            } catch (Exception e) {
-                throw new RuntimeException(e);
-            }
-        });
-        return CommonResult.fail("该批次下没有文件数据,请重新选择批次!");
+        DataAmplificationTaskBo dataAmplificationTaskBo = new DataAmplificationTaskBo();
+        dataAmplificationTaskBo.setName(dataAmplifyDto.getTaskName());
+        dataAmplificationTaskBo.setStatus(BizConstant.TASK_STATUS_PENDING);
+        dataAmplificationTaskBo.setDataBatchNums(dataAmplifyDto.getBatchNum());
+        dataAmplificationTaskBo.setAugmentationType(dataAmplifyDto.getAugmentationType());
+        dataAmplificationTaskBo.setParameters(JsonUtils.toJsonString(dataAmplifyDto.getOtherParams()));
+        dataAmplificationTaskBo.setDelFlag(0);
+        DataAmplificationTask insertedTask = dataAmplificationTaskService.insertTask(dataAmplificationTaskBo);
+        return this.amplifyForData(insertedTask.getId().toString());
     }
 
     private void initFileInfo(String dest, List<File> extractedImagesFileList, boolean directory, String fileName) {
@@ -477,6 +453,8 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
     private CommonResult<Boolean> doAmplify(DataAmplifyDto dataAmplifyDto) {
         this.updateEmp(Long.valueOf(dataAmplifyDto.getId()));
         String[] split = dataAmplifyDto.getBatchNum().split(",");
+        SysDictDataVo sysDictDataVo = dictDataService.selectDictDataByTypeAndLabel("python_api_address", "python_expand_data");
+
         for (String batchNum : split) {
             QueryWrapper query = query();
             query.eq(Data::getBatchNum, batchNum);
@@ -502,8 +480,6 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
             for (Map<String, String> param : dataAmplifyDto.getOtherParams()) {
                 otherParams.put(param.get("agName"), param.get("defaultValue"));
             }
-            Map<String, String> param = dataAmplifyDto.getOtherParams()
-                .stream().collect(Collectors.toMap(item -> item.get("agName"), item -> item.get("defaultValue")));
             dataListInfo.forEach(dataInfo -> {
                 try {
                     boolean success = true;
@@ -527,7 +503,7 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
                     System.out.println("logPath===>" + logPath);
                     bodyJson.put("logPath", logPath);
                     //实际请求接口,接口未提供,暂且注释
-                    String response = HttpRequest.post(PYTHON_DATA_AMPLIFY_API)
+                    String response = HttpRequest.post(sysDictDataVo.getDictValue())
                         .body(JsonUtils.toJsonString(bodyJson))
                         .execute().body();
 //                String response = "{\"status\":200,\"msg\":\"扩增成功\"}";
@@ -546,12 +522,12 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
                     Date endTime = new Date();
                     DataAmplificationTaskBo update = new DataAmplificationTaskBo();
                     if (taskVo.getInputImagePath() != null) {
-                        update.setInputImagePath(taskVo.getInputImagePath()+"|" + dataInfo.getUrl());
+                        update.setInputImagePath(taskVo.getInputImagePath() + "|" + dataInfo.getUrl());
                     } else {
                         update.setInputImagePath(dataInfo.getUrl());
                     }
                     if (taskVo.getOutputImagePath() != null) {
-                        update.setOutputImagePath(taskVo.getOutputImagePath()+"|" + outputImagePath);
+                        update.setOutputImagePath(taskVo.getOutputImagePath() + "|" + outputImagePath);
                     } else {
                         update.setOutputImagePath(outputImagePath);
                     }