Jelajahi Sumber

fix: 数据扩增入库

Eagle 2 minggu lalu
induk
melakukan
35000faa69

+ 14 - 6
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/PublicController.java

@@ -7,10 +7,7 @@ import com.taais.biz.domain.TaskTrackResultBo;
 import com.taais.biz.domain.bo.*;
 import com.taais.biz.domain.dto.TaskResultDTO;
 import com.taais.biz.service.*;
-import com.taais.biz.service.impl.AlgorithmModelServiceImpl;
-import com.taais.biz.service.impl.DataAmplificationTaskServiceImpl;
-import com.taais.biz.service.impl.ObjectTraceMergeServiceImpl;
-import com.taais.biz.service.impl.TargetIdentificationSubtaskDetailsServiceImpl;
+import com.taais.biz.service.impl.*;
 import com.taais.biz.service.service.impl.ObjectMatchServiceImpl;
 import com.taais.common.core.config.TaaisConfig;
 import com.taais.common.core.core.domain.CommonResult;
@@ -83,6 +80,9 @@ public class PublicController extends BaseController {
     @Resource
     DataAmplificationTaskServiceImpl dataAmplificationTaskService;
 
+    @Resource
+    DataServiceImpl dataService;
+
     @PostMapping("/taskResult")
     public CommonResult<Void> taskResult(@RequestBody TaskResultDTO resultDTO) {
         log.info("taskResult start,params:{}", resultDTO);
@@ -145,13 +145,17 @@ public class PublicController extends BaseController {
             errorMsg = dataProcessService.taskResult(resultDTO);
         } else if (BizConstant.TYPE_DATA_EXPAND.equals(bizType)) {
             DataAmplificationTaskBo bo = dataAmplificationTaskService.getById(resultDTO.getBizId());
+            log.info("data_expand callback: {}, {}", resultDTO, bo);
             if (bo != null) {
                 try {
-                    bo.setInputImagePath(String.valueOf(Integer.parseInt(bo.getOutputImagePath()) - 1));
+                    bo.setInputImagePath(String.valueOf(Integer.parseInt(bo.getInputImagePath()) - 1));
                     if (bo.getInputImagePath().equals("0")) {
                         bo.setEndTime(new Date());
                         bo.setCostSecond((int) ((bo.getEndTime().getTime() - bo.getStartTime().getTime()) / 1000));
                         bo.setStatus(resultDTO.getStatus() == 200 ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_FAILED);
+                        if ("PERSISTENCE".equals(bo.getRemarks())) {
+                            dataService.registerExistedData(bo.getOutputImagePath() + "/result", bo);
+                        }
                     }
                 } catch (Exception e) {
                     log.error("error: {}", e.getMessage());
@@ -180,7 +184,11 @@ public class PublicController extends BaseController {
             return "bizId 不能为null";
         }
         String bizType = resultDTO.getBizType();
-        if(!BizConstant.TYPE_OBJ_MATCH.equals(bizType) && !BizConstant.TYPE_OBJ_TRACE.equals(bizType) && !BizConstant.TYPE_DATA_BIZ_PROCESS.equals(bizType) && !BizConstant.TYPE_DATA_PROCESS.equals(bizType)){
+        if(!BizConstant.TYPE_OBJ_MATCH.equals(bizType) &&
+            !BizConstant.TYPE_OBJ_TRACE.equals(bizType) &&
+            !BizConstant.TYPE_DATA_BIZ_PROCESS.equals(bizType) &&
+            !BizConstant.TYPE_DATA_PROCESS.equals(bizType) &&
+            !BizConstant.TYPE_DATA_EXPAND.equals(bizType)){
             return "status 只能是"+BizConstant.TYPE_DATA_BIZ_PROCESS+"或"+BizConstant.TYPE_DATA_PROCESS;
         }
         return null;

+ 3 - 3
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/TargetIdentificationSubtaskDetailsController.java

@@ -148,7 +148,7 @@ public class TargetIdentificationSubtaskDetailsController extends BaseController
             String url = algorithmConfigService.getByAlgorithmName("多目标检测").getStartApi();
             HttpUtil.post(url, JSONUtil.toJsonStr(params));
 
-            TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(params.get("bizId"));
+            TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(Long.valueOf(params.get("bizId")));
             details.setStatus(BizConstant.TASK_STATUS_PROCESSING);
             details.setStartTime(new Date());
             details.setEndTime(null);
@@ -173,7 +173,7 @@ public class TargetIdentificationSubtaskDetailsController extends BaseController
             String url = algorithmConfigService.getByAlgorithmName("多目标检测").getPauseApi();
             HttpUtil.post(url, JSONUtil.toJsonStr(params));
 
-            TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(params.get("bizId"));
+            TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(Long.valueOf(params.get("bizId")));
             details.setStatus(BizConstant.TASK_STATUS_PENDING);
             details.setEndTime(new Date());
             targetIdentificationSubtaskDetailsService.updateById(details);
@@ -197,7 +197,7 @@ public class TargetIdentificationSubtaskDetailsController extends BaseController
             String url = algorithmConfigService.getByAlgorithmName("多目标检测").getTerminateApi();
             HttpUtil.post(url, JSONUtil.toJsonStr(params));
 
-            TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(params.get("bizId"));
+            TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(Long.valueOf(params.get("bizId")));
             details.setStatus(BizConstant.TASK_STATUS_PENDING);
             details.setEndTime(new Date());
             targetIdentificationSubtaskDetailsService.updateById(details);

+ 4 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataAmplificationTaskServiceImpl.java

@@ -46,6 +46,10 @@ public class DataAmplificationTaskServiceImpl extends BaseServiceImpl<DataAmplif
         return queryWrapper;
     }
 
+    public DataAmplificationTaskBo getByNameAndSubTaskId(String name, Long subTaskId) {
+        return this.getOneAs(query().where(DATA_AMPLIFICATION_TASK.NAME.eq(name)).where(DATA_AMPLIFICATION_TASK.SUB_TASK_ID.eq(subTaskId)), DataAmplificationTaskBo.class);
+    }
+
     @Override
     public DataAmplificationTaskVo selectById(Long id) {
         return this.getOneAs(query().where(DATA_AMPLIFICATION_TASK.ID.eq(id)), DataAmplificationTaskVo.class);

+ 153 - 9
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataServiceImpl.java

@@ -315,18 +315,155 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
         return CommonResult.success("数据集上传成功!");
     }
 
+    public CommonResult<Boolean> registerExistedData(String path, DataAmplificationTaskBo bo) throws IOException {
+        List<File> fileList = Arrays.asList(new File(path + "/images").listFiles());
+        copyFiles(path + "/labels", path + "/images");
+        Data d = new Data();
+
+        d.setLabeled(true);
+        d.setBatchNum(bo.getName());
+        d.setIncrement(bo.getAugmentationType());
+        return uploadExistedData(fileList, d);
+    }
+
+    private void copyFiles(String from, String to) throws IOException {
+        for (File file : new File(from).listFiles()) {
+            File dist = new File(file.getAbsolutePath().replace("labels", "images"));
+            org.apache.commons.io.FileUtils.copyFile(file, dist);
+        }
+    }
+
+    public CommonResult<Boolean> uploadExistedData(List<File> extractedImagesFileList, Data dataInfo) {
+        try {
+            Boolean labeled = dataInfo.getLabeled();
+            //检测图片文件是否为空
+            if (extractedImagesFileList.isEmpty()) {
+                return CommonResult.fail("压缩文件图片为空,请检查后重新上传!");
+            }
+
+            //3获取ID集合
+            List<Long> ids = dataMapper.getIds(extractedImagesFileList.size());
+            if (ids.isEmpty()) {
+                return CommonResult.fail("系统异常!");
+            }
+
+            List<Boolean> labeledList = new ArrayList<>();
+            List<Data> dataList = new ArrayList<>();
+            AtomicInteger countSize = new AtomicInteger();
+            List<Data> finalDataList = dataList;
+            extractedImagesFileList.forEach(fileInfo -> {
+                //获取ID
+                Long id = ids.get(countSize.get());
+                Data data = new Data();
+                //拷贝dataInfo
+                BeanUtils.copyProperties(dataInfo, data);
+                //检测是否标注
+                if (checkLabeled(fileInfo.getPath())) {
+                    labeledList.add(Boolean.TRUE);
+                    data.setLabeled(Boolean.TRUE);
+                } else {
+                    labeledList.add(Boolean.FALSE);
+                    data.setLabeled(Boolean.FALSE);
+                }
+                try {
+                    Path path = Paths.get(fileInfo.getAbsolutePath());
+                    // 使用Files类的readAttributes方法获取文件的基本属性
+                    BasicFileAttributes attrs = Files.readAttributes(path, BasicFileAttributes.class);
+                    // 获取文件的修改时间
+                    Instant creationTime = attrs.lastModifiedTime().toInstant();
+                    // 将Instant转换为Date
+                    Date date = Date.from(creationTime);
+                    //设置文件创建时间
+                    data.setId(id);
+                    data.setGatherTime(date);
+                    data.setName(fileInfo.getName());
+                    //更改图片文件名称
+                    String fileHeaderSuffix = StringUtils.substring(fileInfo.getName(), fileInfo.getName().lastIndexOf("."), fileInfo.getName().length());
+                    String destInfo = fileInfo.getPath().replaceAll(fileInfo.getName(), "");
+                    File newFile = new File(destInfo, id + fileHeaderSuffix);
+                    File odlFile = new File(destInfo, fileInfo.getName());
+                    log.info("更改用户上传图片文件名称:{}", odlFile.renameTo(newFile));
+                    String imagePath = FileUploadUtils.getPathFileName(destInfo, id + fileHeaderSuffix);
+                    data.setUrl(imagePath);
+                    //该图片有标注,更改标注文件名称
+                    if (data.getLabeled()) {
+                        String labeledPath = fileInfo.getPath().replaceFirst("[.][^.]+$", "") + ".txt";
+                        File labeledNewFile = new File(destInfo, id + ".txt");
+                        File labeledOdlFile = new File(labeledPath);
+                        log.info("更改用户上传标注文件名称:{}", labeledOdlFile.renameTo(labeledNewFile));
+                        String labelUrl = FileUploadUtils.getPathFileName(destInfo, id + ".txt");
+                        data.setLabelurl(labelUrl);
+                    }
+                    finalDataList.add(data);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            countSize.getAndIncrement();
+        });
+        // TODO 李兆晏 确认逻辑是否正确 start
+//            //是否选择已标注,如果已标注则需要检测所有图片是否标注
+//            if (!labeledList.isEmpty()) {
+//                // 已标注数量
+//                long unmarkedCount = labeledList.stream().filter(Boolean.FALSE::equals).count();
+//                // 未标注数量
+//                long markedCount = labeledList.stream().filter(Boolean.TRUE::equals).count();
+//                // 如果存在未标注的数据,则返回错误信息
+//                if (unmarkedCount > 0 && labeled) {
+//                    String format = String.format("错误: 已标注文件 %d 个,未标注文件 %d 个", markedCount, unmarkedCount);
+//                    return CommonResult.fail(format);
+//                }
+//                // 如果存在标注的数据,则返回错误信息
+//                if (markedCount > 0 && !labeled) {
+//                    String format = String.format("错误: 已标注文件 %d 个,未标注文件 %d 个", markedCount, unmarkedCount);
+//                    return CommonResult.fail(format);
+//                }
+//            }
+        if (labeled) {
+            // 未标注数量
+            long unmarkedCount = labeledList.stream().filter(Boolean.FALSE::equals).count();
+            // 已标注数量
+            long markedCount = labeledList.stream().filter(Boolean.TRUE::equals).count();
+
+            if (unmarkedCount > 0) {
+                String format = String.format("错误: 已标注文件 %d 个,未标注文件 %d 个", markedCount, unmarkedCount);
+                return CommonResult.fail(format);
+            }
+        }
+        // TODO 李兆晏 确认逻辑是否正确 end
+
+        int batchSize = 100;
+        int totalSize = dataList.size();
+        int page = (int) Math.ceil((double) totalSize / batchSize); // 计算总批次数
+
+        for (int i = 0; i < page; i++) {
+            int fromIndex = i * batchSize;
+            int toIndex = Math.min(fromIndex + batchSize, totalSize); // 确保最后一批不越界
+            List<Data> batchList = new ArrayList<>(dataList.subList(fromIndex, toIndex)); // 生成子列表副本
+            dataMapper.insertBatch(batchList); // 批量插入
+        }
+    } catch (Exception e) {
+        log.error("[uploadDataInfo]数据集处理出现未知异常.e:", e);
+        return CommonResult.fail("系统异常!");
+    }
+        return CommonResult.success("数据集上传成功!");
+    }
+
     @Override
-    @Transactional
-    public CommonResult<Boolean> dataAmplify(DataAmplifyDto dataAmplifyDto) {
+//    @Transactional
+    public CommonResult<Boolean> dataAmplify(DataAmplifyDto amplifyDto) {
         DataAmplificationTaskBo dataAmplificationTaskBo = new DataAmplificationTaskBo();
-        dataAmplificationTaskBo.setName(dataAmplifyDto.getTaskName());
+        dataAmplificationTaskBo.setSubTaskId(1008611L);
+        dataAmplificationTaskBo.setName(amplifyDto.getTaskName());
         dataAmplificationTaskBo.setStatus(BizConstant.TASK_STATUS_PENDING);
-        dataAmplificationTaskBo.setDataBatchNums(dataAmplifyDto.getBatchNum());
-        dataAmplificationTaskBo.setAugmentationType(dataAmplifyDto.getAugmentationType());
-        dataAmplificationTaskBo.setParameters(JsonUtils.toJsonString(dataAmplifyDto.getOtherParams()));
+        dataAmplificationTaskBo.setDataBatchNums(amplifyDto.getBatchNum());
+        dataAmplificationTaskBo.setAugmentationType(amplifyDto.getAugmentationType());
+        dataAmplificationTaskBo.setParameters(JsonUtils.toJsonString(amplifyDto.getOtherParams()));
+//        dataAmplificationTaskBo.setOutputImagePath(String.join(";", records.values()));
+        dataAmplificationTaskBo.setRemarks("PERSISTENCE");
         dataAmplificationTaskBo.setDelFlag(0);
-        DataAmplificationTask insertedTask = dataAmplificationTaskService.insertTask(dataAmplificationTaskBo);
-        return this.amplifyForData(insertedTask.getId().toString());
+        boolean result = dataAmplificationTaskService.insert(dataAmplificationTaskBo);
+        DataAmplificationTaskBo b = dataAmplificationTaskService.getByNameAndSubTaskId(dataAmplificationTaskBo.getName(), dataAmplificationTaskBo.getSubTaskId());
+        return this.amplifyForData(b.getId().toString());
     }
 
     private void initFileInfo(String dest, List<File> extractedImagesFileList, boolean directory, String fileName) {
@@ -467,8 +604,15 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
         params.put("logPath", PATH_PREFIX + filepath + "/log.log");
         params.put("inputImagePath", PATH_PREFIX + filepath);
         params.put("otherParams", realOtherParams);
-
+        if ("PERSISTENCE".equals(taskVo.getRemarks())) {
+            taskVo.setOutputImagePath(PATH_PREFIX + filepath);
+        } else {
+            taskVo.setRemarks(PATH_PREFIX + filepath + "/log.log");
+        }
         String[] outputs = taskVo.getOutputImagePath().split(";");
+        if ("PERSISTENCE".equals(taskVo.getRemarks())) {
+            outputs[0] = outputs[0] + "/result";
+        }
         log.info("check outputs: {} ; {}", outputs, outputs[0]);
         taskVo.setInputImagePath(String.valueOf(outputs.length));
         for (String path : outputs) {