Răsfoiți Sursa

feat: MASC预测接口对接完成

WANGKANG 6 luni în urmă
părinte
comite
d9ca06635f

+ 58 - 32
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -6,9 +6,11 @@ import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.Date;
+import java.util.HashMap;
 import java.util.List;
 
 import cn.hutool.core.util.ObjectUtil;
+import com.alibaba.fastjson2.JSON;
 import com.mybatisflex.core.paginate.Page;
 import com.mybatisflex.core.query.QueryWrapper;
 import com.taais.biz.constant.BizConstant;
@@ -227,9 +229,9 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
     @Override
     public CommonResult start(Long id) {
-        TrackSequence trackSequence = getById(id);
+        TrackSequence entity = getById(id);
 
-        SysOssVo inputOssEntity = ossService.getById(trackSequence.getInputOssId());
+        SysOssVo inputOssEntity = ossService.getById(entity.getInputOssId());
 
         String filePath = inputOssEntity.getFileName();
         String localPath = TaaisConfig.getProfile();
@@ -245,49 +247,58 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         makeDir(inputPath.toString());
         makeDir(outputPath.toString());
 
-        ZipUtils.unzip(resourcePath, inputPath.toString());
+        File file = new File(resourcePath);
+        if (!file.exists()) {
+            ZipUtils.unzip(resourcePath, inputPath.toString());
+        }
 
-        trackSequence.setInputPath(inputPath.toString());
-        trackSequence.setOutputPath(outputPath.toString());
+        entity.setInputPath(inputPath.toString());
+        entity.setOutputPath(outputPath.toString());
 
-        trackSequence.setStartTime(new Date());
+        entity.setStartTime(new Date());
 
-        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(trackSequence.getAlgorithmModelId());
+        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
         AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
 
         StartToInfraredTask startToInfraredTask = new StartToInfraredTask();
-        startToInfraredTask.setBizId(trackSequence.getId());
-
-        if (algorithmConfigTrack.getType() == BizConstant.AlgorithmType.REASONING) {
-            startToInfraredTask.setModel_path(algorithmModelTrack.getModelAddress());
-        }
+        startToInfraredTask.setBizType(BizConstant.BizType.TRACK_SEQUENCE);
+        startToInfraredTask.setBizId(entity.getId());
 
         startToInfraredTask.setOtherParams(algorithmConfigTrack.getParameterConfig());
-        startToInfraredTask.setSource_dir(trackSequence.getInputPath());
-        startToInfraredTask.setResult_dir(trackSequence.getOutputPath());
 
-        startToInfraredTask.setBizType(BizConstant.BizType.TRACK_SEQUENCE);
+        startToInfraredTask.setSource_dir(entity.getInputPath());
+        startToInfraredTask.setResult_dir(entity.getOutputPath());
+
+        if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
+            String modelPath = algorithmModelTrack.getModelAddress() + File.separator + algorithmModelTrack.getModelName();
+            startToInfraredTask.setModel_path(modelPath);
+        }
+
 
         HttpResponseEntity responseEntity = sendPostMsg(algorithmConfigTrack.getAlgorithmAddress(), startToInfraredTask);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
-            trackSequence.setStatus(BizConstant.VideoStatus.RUNNING);
-            updateById(trackSequence);
+            entity.setStatus(BizConstant.VideoStatus.RUNNING);
+            updateById(entity);
             return CommonResult.success("任务开始成功,请等待完成");
         } else {
-            trackSequence.setStatus(BizConstant.VideoStatus.FAILED);
-            updateById(trackSequence);
+            entity.setStatus(BizConstant.VideoStatus.FAILED);
+            updateById(entity);
             return CommonResult.fail("任务开始失败,请检查!");
         }
     }
 
     @Override
     public CommonResult stop(Long id) {
-        TrackSequence trackSequence = getById(id);
+        TrackSequence entity = getById(id);
+
+        StartToInfraredTask startToInfraredTask = new StartToInfraredTask();
+        startToInfraredTask.setBizType(BizConstant.BizType.TRACK_SEQUENCE);
+        startToInfraredTask.setBizId(entity.getId());
 
-        HttpResponseEntity responseEntity = sendPostMsg(task_stop_url, trackSequence);
+        HttpResponseEntity responseEntity = sendPostMsg(task_stop_url, startToInfraredTask);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
-            trackSequence.setStatus(BizConstant.VideoStatus.INTERRUPTED);
-            updateById(trackSequence);
+            entity.setStatus(BizConstant.VideoStatus.INTERRUPTED);
+            updateById(entity);
             return CommonResult.success("终止任务成功");
         } else {
             return CommonResult.fail("终止任务失败");
@@ -296,21 +307,22 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
     @Override
     public ResponseEntity<org.springframework.core.io.Resource> zipImages(Long id) {
-        TrackSequence trackSequence = this.getById(id);
-        if (ObjectUtil.isNull(trackSequence)) {
+        TrackSequence entity = this.getById(id);
+        if (ObjectUtil.isNull(entity)) {
             return ResponseEntity.status(HttpStatus.NOT_FOUND).body(null);
         }
 
-        String outputPath = trackSequence.getOutputPath();
+        String outputPath = entity.getOutputPath();
         String zipFilePath = outputPath + ".zip";
 
-        try {
-            ZipUtils.zipFolderFiles(outputPath, zipFilePath);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-
         File file = new File(zipFilePath);
+        if (!file.exists()) {
+            try {
+                ZipUtils.zipFolderFiles(outputPath, zipFilePath);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
 
         if (!file.exists() || !file.isFile()) {
             return ResponseEntity.status(HttpStatus.NOT_FOUND).body(null);
@@ -338,6 +350,20 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             entity.setCostSecond(null);
         }
         updateById(entity);
+
+        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
+        algorithmModelTrack.setModelStatus("200".equals(status) ? BizConstant.ModelStatus.END : BizConstant.ModelStatus.FAILED);
+
+        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
+        String params =  algorithmConfigTrack.getParameterConfig();
+        HashMap<String, Object> parse = (HashMap<String, Object>) JSON.parse((params));
+
+        if("200".equals(status) && ObjectUtil.isNull(algorithmModelTrack.getModelAddress())) {
+            algorithmModelTrack.setModelAddress(entity.getOutputPath() + File.separator + ((HashMap<String, String>)parse.get("dataset")).get("classes"));
+
+            System.out.println(parse.get("dataset"));
+            algorithmModelTrackService.updateById(algorithmModelTrack);
+        }
         return CommonResult.success();
     }
 }