Bladeren bron

feat: masc测试完毕,优化部分逻辑

WANGKANG 5 maanden geleden
bovenliggende
commit
4b8da90eb7

+ 4 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/TrackSequence.java

@@ -5,6 +5,8 @@
 package com.taais.biz.domain;
 
 import java.util.Date;
+
+import com.alibaba.excel.annotation.ExcelProperty;
 import com.fasterxml.jackson.annotation.JsonFormat;
 import com.mybatisflex.annotation.Column;
 import com.mybatisflex.annotation.Id;
@@ -84,6 +86,8 @@ private static final long serialVersionUID = 1L;
     private String logPath;
 
     private Long parentTaskId;
+    private Long predictTaskId;
+    private Long inputLabelOssId;
 
     /* 目标检测模型id */
     private Long algorithmModelTargetDetectionId;

+ 5 - 2
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/bo/TrackSequenceBo.java

@@ -87,7 +87,7 @@ public class TrackSequenceBo extends BaseEntity {
     /**
      * $column.columnComment
      */
-    @NotNull(message = "上传文件不能为空")
+    // @NotNull(message = "上传文件不能为空")
     private Long inputOssId;
 
     /**
@@ -111,9 +111,12 @@ public class TrackSequenceBo extends BaseEntity {
     private String algorithmParameters;
     private String logPath;
     private Long parentTaskId;
-    @NotNull(message = "目标检测模型ID不能为空")
+    // @NotNull(message = "目标检测模型ID不能为空")
     private Long algorithmModelTargetDetectionId;
 
+    private Long predictTaskId; // 预测任务ID
+    private Long inputLabelOssId; // 输入标签文件OSS ID
+
     // 将其他参数存入Map
     private Map<String, Object> otherParams = new HashMap<>();
 

+ 6 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/vo/TrackSequenceImportVo.java

@@ -97,6 +97,12 @@ public class TrackSequenceImportVo implements Serializable
     @ExcelProperty(value = "父任务ID")
     private Long parentTaskId;
 
+    @ExcelProperty(value = "预测任务ID")
+    private Long predictTaskId;
+
+    @ExcelProperty(value = "输入标签文件OSS ID")
+    private Long inputLabelOssId;
+
     @ExcelProperty(value = "目标检测模型id")
     private Long algorithmModelTargetDetectionId;
 }

+ 6 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/vo/TrackSequenceVo.java

@@ -140,6 +140,12 @@ public class TrackSequenceVo extends BaseEntity implements Serializable {
     @ExcelProperty(value = "父任务ID")
     private Long parentTaskId;
 
+    @ExcelProperty(value = "预测任务ID")
+    private Long predictTaskId;
+
+    @ExcelProperty(value = "输入标签文件OSS ID")
+    private Long inputLabelOssId;
+
     @ExcelProperty(value = "目标检测模型id")
     private Long algorithmModelTargetDetectionId;
 }

+ 2 - 3
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/AlgorithmTaskTrackServiceImpl.java

@@ -217,7 +217,7 @@ public class AlgorithmTaskTrackServiceImpl extends BaseServiceImpl<AlgorithmTask
             ToInfrared toInfraredTask = toInfraredService.insert2(toInfraredBo);
 
             // 需要更新一下inputDatasetOssId,以便于创建trackSequencePredictTask任务
-            Long inputDatasetOssId_new = getNewOssId(toInfraredTask);
+            Long inputDatasetOssId_new = getNewOssId(toInfraredTask.getOutputPath());
             params.put("inputDatasetOssId", inputDatasetOssId_new);
         }
 
@@ -236,8 +236,7 @@ public class AlgorithmTaskTrackServiceImpl extends BaseServiceImpl<AlgorithmTask
         return true;
     }
 
-    private Long getNewOssId(ToInfrared toInfraredTask) {
-        String outputPath = toInfraredTask.getOutputPath();
+    public Long getNewOssId(String outputPath) {
         Path path = Paths.get(outputPath);
         String dirName = path.getFileName().toString();
         String fileName = Constants.RESOURCE_PREFIX + outputPath.substring(TaaisConfig.getProfile().length());

+ 71 - 138
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -29,9 +29,14 @@ import com.taais.common.json.utils.JsonUtils;
 import com.taais.common.orm.core.page.PageQuery;
 import com.taais.common.orm.core.service.impl.BaseServiceImpl;
 import com.taais.common.websocket.utils.WebSocketUtils;
+import com.taais.system.domain.SysOss;
 import com.taais.system.domain.vo.SysOssVo;
 import com.taais.system.service.ISysOssService;
+import com.taais.system.service.impl.SysOssServiceImpl;
+
 import jakarta.annotation.Resource;
+
+import org.apache.commons.lang3.ObjectUtils;
 import org.springframework.beans.BeanUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.core.io.FileSystemResource;
@@ -51,7 +56,6 @@ import java.util.*;
 import static com.taais.biz.constant.BizConstant.VideoStatus.NOT_START;
 import static com.taais.biz.domain.table.TrackSequenceTableDef.TRACK_SEQUENCE;
 import static com.taais.biz.service.impl.TargetDetectionServiceImpl.getFileSize;
-import static com.taais.biz.service.impl.TargetDetectionServiceImpl.port;
 import static com.taais.biz.service.impl.ToInfraredServiceImpl.*;
 import static com.taais.biz.service.impl.VideoStableServiceImpl.*;
 
@@ -128,21 +132,6 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         return algorithmModelTrack.getModelAddress();
     }
 
-    public String getLabelDirPath(Long ossId) {
-        try {
-            SysOssVo ossEntity = ossService.getById(ossId);
-            return getResourcePath(ossEntity);
-        } catch (Exception e) {
-            throw new RuntimeException("oss标签文件不存在");
-        }
-    }
-
-    private String getTxtDirPath(Long id) {
-        TrackSequence entity = getById(id);
-        return entity.getOutputPath();
-
-    }
-
     @Override
     public QueryWrapper query() {
         return super.query().from(TRACK_SEQUENCE);
@@ -222,6 +211,23 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         return PageResult.build(page);
     }
 
+    public TrackSequence packageTrackSequence(TrackSequenceBo entityBo, SysOssVo ossEntity) {
+        TrackSequence entity = new TrackSequence();
+        entity.setInputOssId(entityBo.getInputOssId());
+        entity.setUrl(ossEntity.getUrl());
+        entity.setZipFilePath(ossEntity.getFileName());
+        entity.setName(entityBo.getName());
+        entity.setStatus(NOT_START);
+        entity.setRemarks(entityBo.getRemarks());
+        entity.setAlgorithmModelId(entityBo.getAlgorithmModelId());
+        entity.setAlgorithmId(entityBo.getAlgorithmId());
+        entity.setParentTaskId(entityBo.getParentTaskId());
+        entity.setAlgorithmModelTargetDetectionId(entityBo.getAlgorithmModelTargetDetectionId());
+        entity.setInputLabelOssId(entityBo.getInputLabelOssId());
+        entity.setPredictTaskId(entityBo.getPredictTaskId());
+        return entity;
+    }
+
     /**
      * 新增注视轨迹序列
      *
@@ -231,46 +237,40 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
     @Override
     @Transactional
     public CommonResult insert(TrackSequenceBo entityBo) {
+        Long ossId = entityBo.getInputOssId();
+        if (ObjectUtil.isNotEmpty(entityBo.getPredictTaskId())) { // 如果存在预测任务id,则使用label ossId
+            ossId = entityBo.getInputLabelOssId();
+        }
+
         // 检查input_oss_id是否存在
-        if (ObjectUtil.isNull(entityBo.getInputOssId())) {
-            return CommonResult.fail("上传文件不能为空");
+        if (ObjectUtil.isNull(ossId)) {
+            throw new RuntimeException("oss文件不存在");
         }
 
-        SysOssVo ossEntity = ossService.getById(entityBo.getInputOssId());
+        SysOssVo ossEntity = ossService.getById(ossId);
         if (ObjectUtil.isNull(ossEntity)) {
-            return CommonResult.fail("oss文件不存在");
+            throw new RuntimeException("oss文件不存在");
         }
 
-        TrackSequence entity = new TrackSequence();
-        entity.setInputOssId(entityBo.getInputOssId());
-        entity.setUrl(ossEntity.getUrl());
-        entity.setZipFilePath(ossEntity.getFileName());
-        entity.setName(entityBo.getName());
-        entity.setStatus(NOT_START);
-        entity.setRemarks(entityBo.getRemarks());
-        entity.setAlgorithmModelId(entityBo.getAlgorithmModelId());
-        entity.setAlgorithmId(entityBo.getAlgorithmId());
-        entity.setParentTaskId(entityBo.getParentTaskId());
-        entity.setAlgorithmModelTargetDetectionId(entityBo.getAlgorithmModelTargetDetectionId());
+        TrackSequence entity = packageTrackSequence(entityBo, ossEntity);
         boolean flag = this.save(entity);
 
         if (!flag) {
-            return CommonResult.fail("新增失败");
+            throw new RuntimeException("新增注释轨迹序列任务失败");
         }
 
         entity = updateEntity(entity, entityBo, ossEntity);
 
         // 步骤 6. 保存算法参数到数据库
         boolean __ = this.updateById(entity);// 使用全局配置的雪花算法主键生成器生成ID值
-        if (__) {
-            return CommonResult.success();
-        } else {
-            return CommonResult.fail();
+        if (!__) {
+            throw new RuntimeException("新增注释轨迹序列任务失败");
         }
+        return CommonResult.success(entity);
     }
 
     @Transactional
-    public TrackSequence insert2(TrackSequenceBo entityBo) {
+    public TrackSequence insert2(TrackSequenceBo entityBo) { // 后面估计会优化掉此方法
         // 检查input_oss_id是否存在
         if (ObjectUtil.isNull(entityBo.getInputOssId())) {
             throw new RuntimeException("oss文件不存在");
@@ -281,16 +281,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             throw new RuntimeException("oss文件不存在");
         }
 
-        TrackSequence entity = new TrackSequence();
-        entity.setInputOssId(entityBo.getInputOssId());
-        entity.setUrl(ossEntity.getUrl());
-        entity.setZipFilePath(ossEntity.getFileName());
-        entity.setName(entityBo.getName());
-        entity.setStatus(NOT_START);
-        entity.setRemarks(entityBo.getRemarks());
-        entity.setAlgorithmModelId(entityBo.getAlgorithmModelId());
-        entity.setAlgorithmId(entityBo.getAlgorithmId());
-        entity.setParentTaskId(entityBo.getParentTaskId());
+        TrackSequence entity = packageTrackSequence(entityBo, ossEntity);
         boolean flag = this.save(entity);
 
         if (!flag) {
@@ -349,16 +340,23 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             if (ObjectUtil.isNotEmpty(entity.getAlgorithmModelTargetDetectionId())) {
                 String model_path_TD = getModelPath(entity.getAlgorithmModelTargetDetectionId());
                 result.put("model_path_TD", model_path_TD);
+            } else {
+                throw new RuntimeException("目标检测模型不存在!");
             }
         } else if (algorithmConfig.getType().equals(BizConstant.AlgorithmType.TEST)) {
-            String source_dir = getTestInputPath(ossEntity);
+            if (ObjectUtil.isEmpty(entityBo.getPredictTaskId())) {
+                throw new RuntimeException("预测任务id为空");
+            }
+            TrackSequence predictTask = getById(entityBo.getPredictTaskId());
+            if (ObjectUtil.isNull(predictTask)) {
+                throw new RuntimeException("预测任务不存在");
+            }
+
+            String source_dir = predictTask.getInputPath();
+            String txt_dir = predictTask.getOutputPath();
             String result_dir = getTestOutputPath(entity, ossEntity);
-            String log_path = getLogFilePath(result_dir, entity.getId(), BizConstant.TO_INFRARED_SUFFIX);
-            Long inputLabelOssId = entityBo.getOtherParams().get("inputLabelOssId") == null ? null
-                    : Long.parseLong((String) entityBo.getOtherParams().get("inputLabelOssId"));
-            String label_dir = getLabelDirPath(inputLabelOssId);
-            Long trackSequencePredictTaskId = (Long) entityBo.getOtherParams().get("trackSequencePredictTaskId");
-            String txt_dir = getTxtDirPath(trackSequencePredictTaskId);
+            String log_path = getLogFilePath(result_dir, entity.getId(), BizConstant.TRACK_SEQUENCE_SUFFIX);
+            String label_dir = getResourcePath(ossEntity);
             result.put("source_dir", source_dir);
             result.put("result_dir", result_dir);
             result.put("log_path", log_path);
@@ -548,7 +546,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
                 algorithmModelTrack.setModelAddress(entity.getOutputPath() + File.separator
                         + ((HashMap<String, String>) parse.get("dataset")).get("classes"));
             } catch (Exception e) {
-                System.out.println("未知错误,我也不知道啥原因。。");
+                System.out.println("未知错误");
             }
 
             System.out.println(parse.get("dataset"));
@@ -565,14 +563,11 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
     @Override
     public CommonResult getLog(Long id) {
         TrackSequence entity = getById(id);
-        String outPutPath = entity.getOutputPath();
-        String logPath = outPutPath + File.separator + getLogFileName(entity);
-        System.out.println(logPath);
-        File file = new File(logPath);
+        File file = new File(entity.getLogPath());
         if (!file.exists()) {
             return CommonResult.fail("日志文件不存在!");
         }
-        return CommonResult.success(readLogContent(logPath), "success");
+        return CommonResult.success(readLogContent(entity.getLogPath()), "success");
     }
 
     @Override
@@ -612,19 +607,9 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
     @Override
     public CommonResult addEvaluate(AddEvaluate addEvaluate) {
-        // 检查任务名称
-        if (ObjectUtil.isEmpty(addEvaluate.getName())
-                || (!addEvaluate.getName().startsWith(MASC) && !addEvaluate.getName().startsWith(CAT))) {
-            return CommonResult.fail("任务命名错误,需以MASC或CAT开头!");
-        }
-
-        if (ObjectUtil.isEmpty(addEvaluate.getAlgorithmId())) {
-            return CommonResult.fail("请指定算法!");
-        }
-
         AlgorithmConfigTrackVo config = algorithmConfigTrackService.selectById(addEvaluate.getAlgorithmId());
         if (ObjectUtil.isNull(config)) {
-            return CommonResult.fail("找不到指定的算法!");
+            throw new RuntimeException("算法不存在");
         }
 
         TrackSequence entity = this.getById(addEvaluate.getId());
@@ -632,14 +617,16 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         evaluate_entity.setStatus(NOT_START);
         evaluate_entity.setUrl(entity.getUrl());
         if (addEvaluate.getName().startsWith(MASC)) {
-            evaluate_entity.setInputPath(entity.getInputPath() + ";" + entity.getOutputPath() + File.separator + "gaze"
+            evaluate_entity.setInputPath(entity.getInputPath() + ";" +
+                    entity.getOutputPath() + File.separator + "gaze"
                     + File.separator + "txt");
         } else if (addEvaluate.getName().startsWith(CAT)) {
             File file________ = new File(entity.getInputPath());
             if (!file________.exists()) {
                 return CommonResult.fail("数据集为空!");
             }
-            evaluate_entity.setInputPath(entity.getInputPath() + ";" + entity.getOutputPath() + File.separator + "txt");
+            evaluate_entity.setInputPath(entity.getInputPath() + ";" +
+                    entity.getOutputPath() + File.separator + "txt");
         } else {
             return CommonResult.fail("命名错误!");
         }
@@ -656,7 +643,9 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         if (flag__) {
             Path path = Paths.get(entity.getOutputPath());
             evaluate_entity.setOutputPath(path
-                    .resolveSibling(evaluate_entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX).toString());
+                    .resolveSibling(evaluate_entity.getId().toString() +
+                            BizConstant.TRACK_SEQUENCE_SUFFIX)
+                    .toString());
             boolean flag___ = updateById(evaluate_entity);
             if (flag___) {
                 return CommonResult.success("新增评估任务成功!");
@@ -673,37 +662,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         TrackSequence entity = getById(id);
         String outPutPath = entity.getOutputPath();
         String modelPath = outPutPath;
-        File modelDir = new File(modelPath);
-        if (!modelDir.exists()) {
-            return CommonResult.fail("模型输出目录不存在!");
-        }
-        File[] files = modelDir.listFiles();
-        Integer idx = 0;
-        ArrayList<Map<String, String>> res = new ArrayList<>();
-        for (File file : files) {
-            if (!file.isDirectory()) {
-                continue;
-            }
-            idx += 1;
-            Map<String, String> tmp = new HashMap<>();
-            tmp.put("id", idx.toString());
-            tmp.put("name", file.getName());
-            String filePath = file.getPath();
-            tmp.put("path", filePath);
-            // todo 获取真正的url
-            // http://localhost:9090/profile/upload/2024/10/27/1_1729404909511_20241027153840A001.zip
-            String url = "http://localhost:" + port + Constants.RESOURCE_PREFIX
-                    + filePath.substring(TaaisConfig.getProfile().length());
-            url = url.replaceAll("\\\\", "/"); // windows
-            tmp.put("url", url);
-
-            double fileSize = (getFileSize(file) / (1024.0 * 1024.0));
-            DecimalFormat decimalFormat = new DecimalFormat("#.##");
-            String formatFileSize = decimalFormat.format(fileSize);
-            tmp.put("size", formatFileSize + "MB");
-            res.add(tmp);
-        }
-        return CommonResult.success(res, "success");
+        return getModelList_(modelPath, TaaisConfig.getProfile());
     }
 
     @Override
@@ -740,48 +699,22 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
     @Override
     public CommonResult previewEvaluateResult(Long id) {
-        TrackSequence entity = getById(id);
-
-        SysOssVo inputOssEntity = ossService.getById(entity.getInputOssId());
-
-        String filePath = inputOssEntity.getFileName();
-        String localPath = TaaisConfig.getProfile();
-        String resourcePath = localPath + StringUtils.substringAfter(filePath, Constants.RESOURCE_PREFIX);
-
-        String fileName = StringUtils.substringAfterLast(filePath, "/");
-        String fileName_without_suffix = removeFileExtension(fileName);
-
-        Path path = Paths.get(resourcePath);
-        Path resultPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX
-                + File.separator + BizConstant.RESULT_JSON_NAME);
-
-        if (!new File(resultPath.toString()).exists()) {
-            return CommonResult.fail("评估结果文件不存在!");
-        }
-
-        List<Dict> resultMap = parseJsonMapList(resultPath);
-        if (ObjectUtil.isEmpty(resultMap)) {
-            return CommonResult.fail("获取结果文件失败");
-        }
-
-        return CommonResult.success(resultMap);
+        return getStatisticsResult(id);
     }
 
     @Override
     public CommonResult getStatisticsResult(Long id) {
         TrackSequence entity = getById(id);
-        SysOssVo inputOssEntity = ossService.getById(entity.getInputOssId());
 
-        String filePath = inputOssEntity.getFileName();
-        String localPath = TaaisConfig.getProfile();
-        String resourcePath = localPath + StringUtils.substringAfter(filePath, Constants.RESOURCE_PREFIX);
+        String resultFilePath = entity.getOutputPath() + File.separator + BizConstant.RESULT_JSON_NAME;
 
-        Path path = Paths.get(resourcePath);
-        Path outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX);
+        Path resultPath = Paths.get(resultFilePath);
 
-        Path statisticsResultPath = outputPath.resolve(BizConstant.RESULT_JSON_NAME);
+        if (!new File(resultPath.toString()).exists()) {
+            return CommonResult.fail("评估结果文件不存在!");
+        }
 
-        List<Dict> resultMap = parseJsonMapList(statisticsResultPath);
+        List<Dict> resultMap = parseJsonMapList(resultPath);
         if (ObjectUtil.isEmpty(resultMap)) {
             return CommonResult.fail("获取结果文件失败");
         }