Bladeren bron

feat: MASC预测

WANGKANG 5 maanden geleden
bovenliggende
commit
589f75aca7

+ 3 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/TrackSequence.java

@@ -84,4 +84,7 @@ private static final long serialVersionUID = 1L;
     private String logPath;
 
     private Long parentTaskId;
+
+    /* 目标检测模型id */
+    private Long algorithmModelTargetDetectionId;
 }

+ 5 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/bo/TrackSequenceBo.java

@@ -6,6 +6,7 @@ package com.taais.biz.domain.bo;
 
 import com.fasterxml.jackson.annotation.JsonAnySetter;
 import com.fasterxml.jackson.annotation.JsonFormat;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import com.taais.biz.domain.TrackSequence;
 import com.taais.common.orm.core.domain.BaseEntity;
 import io.github.linpeilie.annotations.AutoMapper;
@@ -18,6 +19,8 @@ import java.util.Date;
 import java.util.HashMap;
 import java.util.Map;
 
+import org.springframework.beans.factory.annotation.Value;
+
 /**
  * 注视轨迹序列业务对象 track_sequence
  *
@@ -104,11 +107,12 @@ public class TrackSequenceBo extends BaseEntity {
 
     private Long algorithmModelId;
 
-    @NotNull(message = "算法不能为空")
     private Long algorithmId;
     private String algorithmParameters;
     private String logPath;
     private Long parentTaskId;
+    @NotNull(message = "目标检测模型ID不能为空")
+    private Long algorithmModelTargetDetectionId;
 
     // 将其他参数存入Map
     private Map<String, Object> otherParams = new HashMap<>();

+ 3 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/vo/TrackSequenceImportVo.java

@@ -96,4 +96,7 @@ public class TrackSequenceImportVo implements Serializable
 
     @ExcelProperty(value = "父任务ID")
     private Long parentTaskId;
+
+    @ExcelProperty(value = "目标检测模型id")
+    private Long algorithmModelTargetDetectionId;
 }

+ 3 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/vo/TrackSequenceVo.java

@@ -139,4 +139,7 @@ public class TrackSequenceVo extends BaseEntity implements Serializable {
 
     @ExcelProperty(value = "父任务ID")
     private Long parentTaskId;
+
+    @ExcelProperty(value = "目标检测模型id")
+    private Long algorithmModelTargetDetectionId;
 }

+ 3 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/ToInfraredServiceImpl.java

@@ -50,6 +50,7 @@ import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.text.DecimalFormat;
 import java.util.*;
+import java.util.stream.Collectors;
 
 import static com.taais.biz.constant.BizConstant.VideoStatus.NOT_START;
 import static com.taais.biz.domain.table.ToInfraredTableDef.TO_INFRARED;
@@ -249,7 +250,8 @@ public class ToInfraredServiceImpl extends BaseServiceImpl<ToInfraredMapper, ToI
                 return Double.parseDouble(val);
             } catch (Exception e2) {
                 if (val.contains(",")) {
-                    return Arrays.asList(val.split(","));
+                    val = val.replaceAll("\\[", "").replaceAll("\\]", "").replaceAll("\\s", "");
+                    return Arrays.stream(val.split(",")).map(Integer::parseInt).collect(Collectors.toList());
                 } else {
                     return Boolean.parseBoolean(val);
                 }

+ 60 - 33
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -59,10 +59,11 @@ import static com.taais.biz.service.impl.VideoStableServiceImpl.*;
  * 注视轨迹序列Service业务层处理
  *
  * @author wangkang
- * 2024-09-22
+ *         2024-09-22
  */
 @Service
-public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMapper, TrackSequence> implements ITrackSequenceService {
+public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMapper, TrackSequence>
+        implements ITrackSequenceService {
 
     private static final String MASC = "MASC";
     private static final String CAT = "CAT";
@@ -95,7 +96,6 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         return outputPath.toString();
     }
 
-
     public String getPredictInputPath(SysOssVo ossEntity) {
         // todo
         return getUnZipDirPath(ossEntity);
@@ -115,8 +115,16 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         return getTrainOutputPath(entity, ossEntity);
     }
 
-    private String getModelPath(TrackSequence entity) {
-        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
+    private String getModelPath(Long modelId) {
+        if (ObjectUtil.isNull(modelId)) {
+            System.out.println("模型id为空");
+            return "";
+        }
+        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(modelId);
+        if (ObjectUtil.isNull(algorithmModelTrack)) {
+            System.out.println("模型不存在");
+            return "";
+        }
         return algorithmModelTrack.getModelAddress();
     }
 
@@ -135,7 +143,6 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
     }
 
-
     @Override
     public QueryWrapper query() {
         return super.query().from(TRACK_SEQUENCE);
@@ -244,6 +251,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         entity.setAlgorithmModelId(entityBo.getAlgorithmModelId());
         entity.setAlgorithmId(entityBo.getAlgorithmId());
         entity.setParentTaskId(entityBo.getParentTaskId());
+        entity.setAlgorithmModelTargetDetectionId(entityBo.getAlgorithmModelTargetDetectionId());
         boolean flag = this.save(entity);
 
         if (!flag) {
@@ -308,13 +316,14 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             throw new RuntimeException("算法配置参数为空");
         }
 
-        Map<String, Object> algorithmParameters = getAlgorithmParameters(algorithmConfig.getParameters(), entityBo.getOtherParams());
+        Map<String, Object> algorithmParameters = getAlgorithmParameters(algorithmConfig.getParameters(),
+                entityBo.getOtherParams());
 
         // 步骤4. 构造可以直接传给前端的map数据结构
         Map<String, Object> result = new HashMap<>();
 
-        result.put("biz_id", entity.getId());
-        result.put("biz_type", BizConstant.BizType.TRACK_SEQUENCE);
+        result.put("bizId", entity.getId());
+        result.put("bizType", BizConstant.BizType.TRACK_SEQUENCE);
 
         if (algorithmConfig.getType().equals(BizConstant.AlgorithmType.TRAIN)) {
             String source_dir = getTrainInputPath(ossEntity);
@@ -327,23 +336,26 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             String source_dir = getPredictInputPath(ossEntity);
             String result_dir = getPredictOutputPath(entity, ossEntity);
             String log_path = getLogFilePath(result_dir, entity.getId(), BizConstant.TRACK_SEQUENCE_SUFFIX);
-            String model_path = getModelPath(entity);
+
             result.put("source_dir", source_dir);
             result.put("result_dir", result_dir);
             result.put("log_path", log_path);
-            result.put("model_path", model_path);
 
-            if (ObjectUtil.isNotEmpty(entityBo.getOtherParams().get("algorithmModelTargetDetectionId"))) {
-                Long model_TD_Id = Long.parseLong((String) entityBo.getOtherParams().get("algorithmModelTargetDetectionId"));
-                AlgorithmModelTrack model_TD = algorithmModelTrackService.getById(model_TD_Id);
-                String model_path_TD = model_TD.getModelAddress();
+            if (ObjectUtil.isNotEmpty(entity.getAlgorithmModelId())) {
+                String model_path = getModelPath(entity.getAlgorithmModelId());
+                result.put("model_path", model_path);
+            }
+
+            if (ObjectUtil.isNotEmpty(entity.getAlgorithmModelTargetDetectionId())) {
+                String model_path_TD = getModelPath(entity.getAlgorithmModelTargetDetectionId());
                 result.put("model_path_TD", model_path_TD);
             }
         } else if (algorithmConfig.getType().equals(BizConstant.AlgorithmType.TEST)) {
             String source_dir = getTestInputPath(ossEntity);
             String result_dir = getTestOutputPath(entity, ossEntity);
             String log_path = getLogFilePath(result_dir, entity.getId(), BizConstant.TO_INFRARED_SUFFIX);
-            Long inputLabelOssId = entityBo.getOtherParams().get("inputLabelOssId") == null ? null : Long.parseLong((String) entityBo.getOtherParams().get("inputLabelOssId"));
+            Long inputLabelOssId = entityBo.getOtherParams().get("inputLabelOssId") == null ? null
+                    : Long.parseLong((String) entityBo.getOtherParams().get("inputLabelOssId"));
             String label_dir = getLabelDirPath(inputLabelOssId);
             Long trackSequencePredictTaskId = (Long) entityBo.getOtherParams().get("trackSequencePredictTaskId");
             String txt_dir = getTxtDirPath(trackSequencePredictTaskId);
@@ -462,8 +474,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         String outputPath = "";
         String zipFilePath = "";
 
-        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
-        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
+        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService
+                .getById(entity.getAlgorithmId());
 
         if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
             outputPath = entity.getOutputPath() + File.separator + "predict";
@@ -501,7 +513,9 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         }
 
         org.springframework.core.io.Resource resource = new FileSystemResource(file);
-        return ResponseEntity.ok().header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + file.getName() + "\"").header(HttpHeaders.CONTENT_TYPE, "application/octet-stream").body(resource);
+        return ResponseEntity.ok()
+                .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + file.getName() + "\"")
+                .header(HttpHeaders.CONTENT_TYPE, "application/octet-stream").body(resource);
     }
 
     @Override
@@ -521,15 +535,18 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         updateById(entity);
 
         AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
-        algorithmModelTrack.setModelStatus("200".equals(status) ? BizConstant.ModelStatus.END : BizConstant.ModelStatus.FAILED);
+        algorithmModelTrack
+                .setModelStatus("200".equals(status) ? BizConstant.ModelStatus.END : BizConstant.ModelStatus.FAILED);
 
-        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
+        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService
+                .getById(algorithmModelTrack.getAlgorithmId());
         String params = algorithmConfigTrack.getParameterConfig();
         HashMap<String, Object> parse = (HashMap<String, Object>) JSON.parse((params));
 
         if ("200".equals(status) && ObjectUtil.isNull(algorithmModelTrack.getModelAddress())) {
             try {
-                algorithmModelTrack.setModelAddress(entity.getOutputPath() + File.separator + ((HashMap<String, String>) parse.get("dataset")).get("classes"));
+                algorithmModelTrack.setModelAddress(entity.getOutputPath() + File.separator
+                        + ((HashMap<String, String>) parse.get("dataset")).get("classes"));
             } catch (Exception e) {
                 System.out.println("未知错误,我也不知道啥原因。。");
             }
@@ -576,7 +593,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             return CommonResult.success(res, "success");
         } else if (entity.getName().startsWith(CAT)) {
             AlgorithmModelTrack modelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
-            AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(modelTrack.getAlgorithmId());
+            AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService
+                    .getById(modelTrack.getAlgorithmId());
 
             AlgorithmModelTrackVo res = new AlgorithmModelTrackVo();
 
@@ -595,7 +613,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
     @Override
     public CommonResult addEvaluate(AddEvaluate addEvaluate) {
         // 检查任务名称
-        if (ObjectUtil.isEmpty(addEvaluate.getName()) || (!addEvaluate.getName().startsWith(MASC) && !addEvaluate.getName().startsWith(CAT))) {
+        if (ObjectUtil.isEmpty(addEvaluate.getName())
+                || (!addEvaluate.getName().startsWith(MASC) && !addEvaluate.getName().startsWith(CAT))) {
             return CommonResult.fail("任务命名错误,需以MASC或CAT开头!");
         }
 
@@ -613,7 +632,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         evaluate_entity.setStatus(NOT_START);
         evaluate_entity.setUrl(entity.getUrl());
         if (addEvaluate.getName().startsWith(MASC)) {
-            evaluate_entity.setInputPath(entity.getInputPath() + ";" + entity.getOutputPath() + File.separator + "gaze" + File.separator + "txt");
+            evaluate_entity.setInputPath(entity.getInputPath() + ";" + entity.getOutputPath() + File.separator + "gaze"
+                    + File.separator + "txt");
         } else if (addEvaluate.getName().startsWith(CAT)) {
             File file________ = new File(entity.getInputPath());
             if (!file________.exists()) {
@@ -635,7 +655,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
         if (flag__) {
             Path path = Paths.get(entity.getOutputPath());
-            evaluate_entity.setOutputPath(path.resolveSibling(evaluate_entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX).toString());
+            evaluate_entity.setOutputPath(path
+                    .resolveSibling(evaluate_entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX).toString());
             boolean flag___ = updateById(evaluate_entity);
             if (flag___) {
                 return CommonResult.success("新增评估任务成功!");
@@ -671,7 +692,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             tmp.put("path", filePath);
             // todo 获取真正的url
             // http://localhost:9090/profile/upload/2024/10/27/1_1729404909511_20241027153840A001.zip
-            String url = "http://localhost:" + port + Constants.RESOURCE_PREFIX + filePath.substring(TaaisConfig.getProfile().length());
+            String url = "http://localhost:" + port + Constants.RESOURCE_PREFIX
+                    + filePath.substring(TaaisConfig.getProfile().length());
             url = url.replaceAll("\\\\", "/"); // windows
             tmp.put("url", url);
 
@@ -698,15 +720,20 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         String fileName_without_suffix = removeFileExtension(fileName);
 
         Path path = Paths.get(resourcePath);
-        Path inputPath = path.resolveSibling(fileName_without_suffix + BizConstant.UNZIP_SUFFIX + File.separator + "images");
-        Path outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX + File.separator + "gaze" + File.separator + "images");
+        Path inputPath = path
+                .resolveSibling(fileName_without_suffix + BizConstant.UNZIP_SUFFIX + File.separator + "images");
+        Path outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX
+                + File.separator + "gaze" + File.separator + "images");
 
         File outputPathDir = new File(outputPath.toString());
         if (!outputPathDir.exists()) {
-            outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX + File.separator + "images");
+            outputPath = path.resolveSibling(
+                    entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX + File.separator + "images");
         }
 
-//        String urlPrefix = inputOssEntity.getUrl().substring(0, inputOssEntity.getUrl().indexOf(Constants.RESOURCE_PREFIX) + Constants.RESOURCE_PREFIX.length());
+        // String urlPrefix = inputOssEntity.getUrl().substring(0,
+        // inputOssEntity.getUrl().indexOf(Constants.RESOURCE_PREFIX) +
+        // Constants.RESOURCE_PREFIX.length());
         String urlPrefix = Constants.RESOURCE_PREFIX;
         return getCompareImage(urlPrefix, inputPath.toString(), outputPath.toString());
     }
@@ -725,7 +752,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         String fileName_without_suffix = removeFileExtension(fileName);
 
         Path path = Paths.get(resourcePath);
-        Path resultPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX + File.separator + BizConstant.RESULT_JSON_NAME);
+        Path resultPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX
+                + File.separator + BizConstant.RESULT_JSON_NAME);
 
         if (!new File(resultPath.toString()).exists()) {
             return CommonResult.fail("评估结果文件不存在!");
@@ -761,5 +789,4 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         return CommonResult.success(resultMap);
     }
 
-
 }