瀏覽代碼

feat: 目标检测任务展示编辑完成,路径还未完成

WANGKANG 5 月之前
父節點
當前提交
82f6e7280a
共有 1 個文件被更改,包括 105 次插入102 次删除
  1. 105 102
      taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetDetectionServiceImpl.java

+ 105 - 102
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetDetectionServiceImpl.java

@@ -24,6 +24,7 @@ import com.taais.common.core.core.domain.CommonResult;
 import com.taais.common.core.core.page.PageResult;
 import com.taais.common.core.utils.MapstructUtils;
 import com.taais.common.core.utils.StringUtils;
+import com.taais.common.json.utils.JsonUtils;
 import com.taais.common.orm.core.page.PageQuery;
 import com.taais.common.orm.core.service.impl.BaseServiceImpl;
 import com.taais.common.websocket.utils.WebSocketUtils;
@@ -48,7 +49,7 @@ import java.util.*;
 
 import static com.taais.biz.constant.BizConstant.VideoStatus.NOT_START;
 import static com.taais.biz.domain.table.TargetDetectionTableDef.TARGET_DETECTION;
-import static com.taais.biz.service.impl.ToInfraredServiceImpl.readLogContent;
+import static com.taais.biz.service.impl.ToInfraredServiceImpl.*;
 import static com.taais.biz.service.impl.VideoStableServiceImpl.*;
 
 /**
@@ -62,9 +63,6 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
     @Value("${server.port}")
     public static String port;
 
-    @Value("${server.task_stop_url}")
-    private String task_stop_url;
-
     @Resource
     private TargetDetectionMapper targetDetectionMapper;
     @Autowired
@@ -120,7 +118,7 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
     @Override
     public List<TargetDetectionVo> selectList(TargetDetectionBo targetDetectionBo) {
         QueryWrapper queryWrapper = buildQueryWrapper(targetDetectionBo);
-        queryWrapper.orderBy(TARGET_DETECTION.CREATE_TIME,Boolean.FALSE);
+        queryWrapper.orderBy(TARGET_DETECTION.CREATE_TIME, Boolean.FALSE);
         return this.listAs(queryWrapper, TargetDetectionVo.class);
     }
 
@@ -133,17 +131,22 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
     @Override
     public PageResult<TargetDetectionVo> selectPage(TargetDetectionBo targetDetectionBo) {
         QueryWrapper queryWrapper = buildQueryWrapper(targetDetectionBo);
-        queryWrapper.orderBy(TARGET_DETECTION.CREATE_TIME,Boolean.FALSE);
+        queryWrapper.orderBy(TARGET_DETECTION.CREATE_TIME, Boolean.FALSE);
         Page<TargetDetectionVo> page = this.pageAs(PageQuery.build(), queryWrapper, TargetDetectionVo.class);
         page.getRecords().forEach(entity -> {
             Long modelId = entity.getAlgorithmModelId();
-            AlgorithmModelTrackVo model = algorithmModelTrackService.selectById(modelId);
-            if (ObjectUtil.isNotNull(model)) {
-                AlgorithmConfigTrackVo config = algorithmConfigTrackService.selectById(model.getAlgorithmId());
+            if (ObjectUtil.isNotNull(modelId)) {
+                AlgorithmModelTrack model = algorithmModelTrackService.getById(modelId);
+                entity.setModelName(model.getModelName());
+
+            }
+
+            Long algorithmId = entity.getAlgorithmId();
+            if (ObjectUtil.isNotNull(algorithmId)) {
+                AlgorithmConfigTrackVo config = algorithmConfigTrackService.selectById(algorithmId);
                 entity.setType(config.getType());
                 entity.setSubsystem(config.getSubsystem());
                 entity.setAlgorithmName(config.getAlgorithmName());
-                entity.setModelName(model.getModelName());
             }
         });
         page.getRecords().sort(Comparator.comparing(TargetDetectionVo::getCreateTime).reversed());
@@ -153,47 +156,94 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
     /**
      * 新增目标检测
      *
-     * @param targetDetectionBo 目标检测Bo
+     * @param entityBo 目标检测Bo
      * @return 结果:true 操作成功,false 操作失败
      */
     @Override
-    public boolean insert(TargetDetectionBo targetDetectionBo) {
+    @Transactional
+    public boolean insert(TargetDetectionBo entityBo) {
         // 检查input_oss_id是否存在
-        if (ObjectUtil.isNull(targetDetectionBo.getInputOssId())) {
+        if (ObjectUtil.isNull(entityBo.getInputOssId())) {
             return false;
         }
 
-        SysOssVo ossEntity = ossService.getById(targetDetectionBo.getInputOssId());
+        SysOssVo ossEntity = ossService.getById(entityBo.getInputOssId());
         if (ObjectUtil.isNull(ossEntity)) {
             return false;
         }
 
-        TargetDetection targetDetection = new TargetDetection();
+        TargetDetection entity = new TargetDetection();
+        entity.setInputOssId(entityBo.getInputOssId());
+        entity.setUrl(ossEntity.getUrl());
+        entity.setZipFilePath(ossEntity.getFileName());
+        entity.setName(entityBo.getName());
+        entity.setStatus(NOT_START);
+        entity.setRemarks(entityBo.getRemarks());
+        entity.setAlgorithmModelId(entityBo.getAlgorithmModelId());
+        entity.setAlgorithmId(entityBo.getAlgorithmId());
+        boolean flag = this.save(entity);
+
+        if (!flag) {
+            return false;
+        }
 
-        targetDetection.setInputOssId(targetDetectionBo.getInputOssId());
-        targetDetection.setUrl(ossEntity.getUrl());
+        entity = updateEntity(entity, entityBo, ossEntity);
 
-        String filePath = ossEntity.getFileName();
-        String localPath = TaaisConfig.getProfile();
-        String resourcePath = localPath + StringUtils.substringAfter(filePath, Constants.RESOURCE_PREFIX);
-        targetDetection.setInputPath(resourcePath);
+        // 步骤 6. 保存算法参数到数据库
+        return this.updateById(entity);// 使用全局配置的雪花算法主键生成器生成ID值
+    }
 
-        String fileName = StringUtils.substringAfterLast(filePath, "/");
-        String fileName_without_suffix = removeFileExtension(fileName);
+    private TargetDetection updateEntity(TargetDetection entity, TargetDetectionBo entityBo, SysOssVo ossEntity) {
+        // 从这里开始,配置任务的algorithm_parameters参数
+        // 步骤 1. 首先根据算法id获取算法配置
+        AlgorithmConfigTrack algorithmConfig = algorithmConfigTrackService.getById(entity.getAlgorithmId());
+        if (ObjectUtil.isNull(algorithmConfig)) {
+            throw new RuntimeException("算法配置参数为空");
+        }
 
-        Path path = Paths.get(resourcePath);
-        Path outPath = path.resolveSibling(fileName_without_suffix + "_images" + System.currentTimeMillis());
-        targetDetection.setOutputPath(outPath.toString());
+        Map<String, Object> algorithmParameters = getAlgorithmParameters(algorithmConfig.getParameters(), entityBo.getOtherParams());
 
-        targetDetection.setZipFilePath(path.resolveSibling(fileName_without_suffix + ".zip").toString());
+        // 步骤4. 构造可以直接传给前端的map数据结构
 
-        targetDetection.setName(targetDetectionBo.getName());
-        targetDetection.setStatus(NOT_START);
-        targetDetection.setRemarks(targetDetectionBo.getRemarks());
+        Map<String, Object> result = getCommonResultParams(entity.getId(), ossEntity);
 
-        targetDetection.setAlgorithmModelId(targetDetectionBo.getAlgorithmModelId());
+        if (algorithmConfig.getType().equals(BizConstant.AlgorithmType.REASONING)) {
+            String model_path = getModelPath(entity);
+            result.put("model_path", model_path);
+        }
+        result.put("otherParams", algorithmParameters);
+
+        entity.setInputPath((String) result.get("source_dir"));
+        entity.setOutputPath((String) result.get("result_dir"));
+        entity.setLogPath((String) result.get("log_path"));
+
+        // 步骤 5. 将算法参数map序列化为json字符串,保存到数据库中
+        entity.setAlgorithmParameters(JsonUtils.toJsonString(result));
+
+//        makeDir(entity.getInputPath());
+//        makeDir(entity.getOutputPath());
+//
+//        File file = new File(entity.getInputPath());
+//        if (!file.exists()) {
+//            String resourcePath = getResourcePath(ossService.getById(entity.getInputOssId()));
+//            ZipUtils.unzip(resourcePath, entity.getInputPath());
+//        }
+
+        return entity;
+    }
 
-        return this.save(targetDetection);// 使用全局配置的雪花算法主键生成器生成ID值
+    private Map<String, Object> getCommonResultParams(Long id, SysOssVo ossEntity) {
+        Map<String, Object> result = new HashMap<>();
+        String source_dir = getTrainInputPath(ossEntity);
+        String result_dir = getTrainOutputPath(id, ossEntity);
+        String log_path = getLogFilePath(result_dir, id, BizConstant.TARGET_DETECTION_SUFFIX);
+
+        result.put("biz_id", id);
+        result.put("biz_type", BizConstant.BizType.TARGET_DETECTION);
+        result.put("source_dir", source_dir);
+        result.put("result_dir", result_dir);
+        result.put("log_path", log_path);
+        return result;
     }
 
     /**
@@ -204,9 +254,17 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
      */
     @Override
     public boolean update(TargetDetectionBo targetDetectionBo) {
-        TargetDetection targetDetection = MapstructUtils.convert(targetDetectionBo, TargetDetection.class);
-        if (ObjectUtil.isNotNull(targetDetection) && ObjectUtil.isNotNull(targetDetection.getId())) {
-            boolean updated = this.updateById(targetDetection);
+        TargetDetection entity = MapstructUtils.convert(targetDetectionBo, TargetDetection.class);
+
+        SysOssVo ossEntity = ossService.getById(targetDetectionBo.getInputOssId());
+        if (ObjectUtil.isNull(ossEntity)) {
+            throw new RuntimeException("oss文件不存在");
+        }
+
+        entity = updateEntity(entity, targetDetectionBo, ossEntity);
+
+        if (ObjectUtil.isNotNull(entity) && ObjectUtil.isNotNull(entity.getId())) {
+            boolean updated = this.updateById(entity);
             return updated;
         }
         return false;
@@ -229,7 +287,7 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
         return null;
     }
 
-    public String getTrainOutputPath(ToInfrared entity, SysOssVo ossEntity) {
+    public String getTrainOutputPath(Long id, SysOssVo ossEntity) {
         // todo
         return null;
     }
@@ -244,77 +302,20 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
         return null;
     }
 
-    private String getLogFileName(TargetDetection entity) {
-        return entity.getId() + BizConstant.TARGET_DETECTION_SUFFIX + ".log";
-    }
-
-    public String getLogFilePath(TargetDetection entity) {
-        return entity.getOutputPath() + File.separator + getLogFileName(entity);
+    private String getModelPath(TargetDetection entity) {
+        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
+        return algorithmModelTrack.getModelAddress();
     }
 
     @Override
     public CommonResult start(Long id) {
-        TargetDetection entity = getById(id);
-
-        SysOssVo inputOssEntity = ossService.getById(entity.getInputOssId());
-
-        String filePath = inputOssEntity.getFileName();
-        String localPath = TaaisConfig.getProfile();
-        String resourcePath = localPath + StringUtils.substringAfter(filePath, Constants.RESOURCE_PREFIX);
-
-        String fileName = StringUtils.substringAfterLast(filePath, "/");
-        String fileName_without_suffix = removeFileExtension(fileName);
-
-        Path path = Paths.get(resourcePath);
-        Path inputPath = path.resolveSibling(fileName_without_suffix + BizConstant.UNZIP_SUFFIX);
-        Path outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TARGET_DETECTION_SUFFIX);
-
-//        makeDir(inputPath.toString());
-        makeDir(outputPath.toString());
-
-        File file = new File(inputPath.toString());
-        if (!file.exists()) {
-            ZipUtils.unzip(resourcePath, inputPath.toString());
-        }
-
-        entity.setInputPath(inputPath.toString());
-        entity.setOutputPath(outputPath.toString());
-
+       TargetDetection entity = getById(id);
         entity.setStartTime(new Date());
 
-        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
-        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
-
-        StartTaskConfig startTaskConfig = new StartTaskConfig();
-        startTaskConfig.setBizType(BizConstant.BizType.TARGET_DETECTION);
-        startTaskConfig.setBizId(entity.getId());
-
-        startTaskConfig.setOtherParams(algorithmConfigTrack.getParameterConfig());
-
-        startTaskConfig.setSource_dir(entity.getInputPath());
-        startTaskConfig.setResult_dir(entity.getOutputPath());
-        startTaskConfig.setLog_path(getLogFilePath(entity));
-
-        if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
-            String modelPath = algorithmModelTrack.getModelAddress();
-            startTaskConfig.setModel_path(modelPath);
-        } else if (BizConstant.AlgorithmType.TRAIN.equals(algorithmConfigTrack.getType())) {
-            // 这时候需要在之前的基础上,添加类别的路径
-            File dir = new File(startTaskConfig.getSource_dir());
-            File[] files = dir.listFiles();
-            // 暂时只传第一个文件夹用作训练
-            if (files == null || files.length == 0) {
-                return CommonResult.fail("输入数据集为空!");
-            }
-            for (File file__ : files) {
-                if (file__.isDirectory()) {
-                    startTaskConfig.setSource_dir(file__.getAbsolutePath()); // 设置路径
-                    break;
-                }
-            }
-        }
+        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(entity.getAlgorithmId());
+        Map<String, Object> startTaskConfig = JsonUtils.parseMap(entity.getAlgorithmParameters());
 
-        HttpResponseEntity responseEntity = sendPostMsg(algorithmConfigTrack.getAlgorithmAddress(), startTaskConfig);
+        HttpResponseEntity responseEntity = sendPostMsg(algorithmConfigTrack.getStartApi(), startTaskConfig);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
             entity.setStatus(BizConstant.VideoStatus.RUNNING);
             updateById(entity);
@@ -334,7 +335,9 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
         startTaskConfig.setBizType(BizConstant.BizType.TO_INFRARED);
         startTaskConfig.setBizId(entity.getId());
 
-        HttpResponseEntity responseEntity = sendPostMsg(task_stop_url, startTaskConfig);
+        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(entity.getAlgorithmId());
+
+        HttpResponseEntity responseEntity = sendPostMsg(algorithmConfigTrack.getTerminateApi(), startTaskConfig);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
             entity.setStatus(BizConstant.VideoStatus.INTERRUPTED);
             updateById(entity);
@@ -430,7 +433,7 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
     @Override
     public CommonResult getLog(Long id) {
         TargetDetection entity = getById(id);
-        String logPath = getLogFilePath(entity);
+        String logPath = entity.getLogPath();
         System.out.println(logPath);
         File file = new File(logPath);
         if (!file.exists()) {