Эх сурвалжийг харах

feat: 目标检测训练接口对接完成

WANGKANG 8 сар өмнө
parent
commit
0f630dd57a

+ 3 - 3
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/PublicController.java

@@ -111,13 +111,13 @@ public class PublicController extends BaseController {
 
     @PostMapping("/task/get_result")
     public CommonResult getResult(@Valid @RequestBody TaskTrackResultBo taskTrackResultBo) {
-        if(taskTrackResultBo.getBizType().equals(BizConstant.BizType.TO_INFRARED)) {
+        if(BizConstant.BizType.TO_INFRARED.equals(taskTrackResultBo.getBizType())) {
             return toInfraredService.getResult(taskTrackResultBo);
         }
-        else if(taskTrackResultBo.getBizType().equals(BizConstant.BizType.TRACK_SEQUENCE)) {
+        else if(BizConstant.BizType.TRACK_SEQUENCE.equals(taskTrackResultBo.getBizType())) {
             return trackSequenceService.getResult(taskTrackResultBo);
         }
-        else if(taskTrackResultBo.getBizType().equals(BizConstant.BizType.TARGET_DETECTION)) {
+        else if(BizConstant.BizType.TARGET_DETECTION.equals(taskTrackResultBo.getBizType())) {
             return targetDetectionService.getResult(taskTrackResultBo);
         }
         else {

+ 39 - 30
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetDetectionServiceImpl.java

@@ -55,8 +55,8 @@ import static com.taais.biz.service.impl.VideoStableServiceImpl.sendPostMsg;
  */
 @Service
 public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionMapper, TargetDetection> implements ITargetDetectionService {
-    @Value("${server.target_detection_stop_url}")
-    private String target_detection_stop_url;
+    @Value("${server.task_stop_url}")
+    private String task_stop_url;
 
     @Resource
     private TargetDetectionMapper targetDetectionMapper;
@@ -229,9 +229,9 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
 
     @Override
     public CommonResult start(Long id) {
-        TargetDetection targetDetection = getById(id);
+        TargetDetection entity = getById(id);
 
-        SysOssVo inputOssEntity = ossService.getById(targetDetection.getInputOssId());
+        SysOssVo inputOssEntity = ossService.getById(entity.getInputOssId());
 
         String filePath = inputOssEntity.getFileName();
         String localPath = TaaisConfig.getProfile();
@@ -247,49 +247,57 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
         makeDir(inputPath.toString());
         makeDir(outputPath.toString());
 
-        ZipUtils.unzip(resourcePath, inputPath.toString());
+        File file = new File(resourcePath);
+        if (!file.exists()) {
+            ZipUtils.unzip(resourcePath, inputPath.toString());
+        }
 
-        targetDetection.setInputPath(inputPath.toString());
-        targetDetection.setOutputPath(outputPath.toString());
+        entity.setInputPath(inputPath.toString());
+        entity.setOutputPath(outputPath.toString());
 
-        targetDetection.setStartTime(new Date());
+        entity.setStartTime(new Date());
 
-        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(targetDetection.getAlgorithmModelId());
+        AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
         AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
 
         StartToInfraredTask startToInfraredTask = new StartToInfraredTask();
-        startToInfraredTask.setBizId(targetDetection.getId());
+        startToInfraredTask.setBizType(BizConstant.BizType.TARGET_DETECTION);
+        startToInfraredTask.setBizId(entity.getId());
+
+        startToInfraredTask.setOtherParams(algorithmConfigTrack.getParameterConfig());
 
-        if (algorithmConfigTrack.getType() == BizConstant.AlgorithmType.REASONING) {
+        startToInfraredTask.setSource_dir(entity.getInputPath());
+        startToInfraredTask.setResult_dir(entity.getOutputPath());
+
+        if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
             startToInfraredTask.setModel_path(algorithmModelTrack.getModelAddress());
         }
 
-        startToInfraredTask.setOtherParams(algorithmConfigTrack.getParameterConfig());
-        startToInfraredTask.setSource_dir(targetDetection.getInputPath());
-        startToInfraredTask.setResult_dir(targetDetection.getOutputPath());
-
-        startToInfraredTask.setBizType(BizConstant.BizType.TARGET_DETECTION);
 
         HttpResponseEntity responseEntity = sendPostMsg(algorithmConfigTrack.getAlgorithmAddress(), startToInfraredTask);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
-            targetDetection.setStatus(BizConstant.VideoStatus.RUNNING);
-            updateById(targetDetection);
+            entity.setStatus(BizConstant.VideoStatus.RUNNING);
+            updateById(entity);
             return CommonResult.success("任务开始成功,请等待完成");
         } else {
-            targetDetection.setStatus(BizConstant.VideoStatus.FAILED);
-            updateById(targetDetection);
+            entity.setStatus(BizConstant.VideoStatus.FAILED);
+            updateById(entity);
             return CommonResult.fail("任务开始失败,请检查!");
         }
     }
 
     @Override
     public CommonResult stop(Long id) {
-        TargetDetection targetDetection = getById(id);
+        TargetDetection entity = getById(id);
+
+        StartToInfraredTask startToInfraredTask = new StartToInfraredTask();
+        startToInfraredTask.setBizType(BizConstant.BizType.TO_INFRARED);
+        startToInfraredTask.setBizId(entity.getId());
 
-        HttpResponseEntity responseEntity = sendPostMsg(target_detection_stop_url, targetDetection);
+        HttpResponseEntity responseEntity = sendPostMsg(task_stop_url, startToInfraredTask);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
-            targetDetection.setStatus(BizConstant.VideoStatus.INTERRUPTED);
-            updateById(targetDetection);
+            entity.setStatus(BizConstant.VideoStatus.INTERRUPTED);
+            updateById(entity);
             return CommonResult.success("终止任务成功");
         } else {
             return CommonResult.fail("终止任务失败");
@@ -324,13 +332,14 @@ public class TargetDetectionServiceImpl extends BaseServiceImpl<TargetDetectionM
         String outputPath = targetDetection.getOutputPath();
         String zipFilePath = outputPath + ".zip";
 
-        try {
-            ZipUtils.zipFolderFiles(outputPath, zipFilePath);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-
         File file = new File(zipFilePath);
+        if (!file.exists()) {
+            try {
+                ZipUtils.zipFolderFiles(outputPath, zipFilePath);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
 
         if (!file.exists() || !file.isFile()) {
             return ResponseEntity.status(HttpStatus.NOT_FOUND).body(null);

+ 3 - 3
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/ToInfraredServiceImpl.java

@@ -56,8 +56,8 @@ import static com.taais.biz.service.impl.VideoStableServiceImpl.*;
 @Service
 @Log4j2
 public class ToInfraredServiceImpl extends BaseServiceImpl<ToInfraredMapper, ToInfrared> implements IToInfraredService {
-    @Value("${server.to_infrared_stop_url}")
-    private String to_infrared_stop_url;
+    @Value("${server.task_stop_url}")
+    private String task_stop_url;
 
     @Autowired
     private ISysOssService ossService;
@@ -300,7 +300,7 @@ public class ToInfraredServiceImpl extends BaseServiceImpl<ToInfraredMapper, ToI
         startToInfraredTask.setBizType(BizConstant.BizType.TO_INFRARED);
         startToInfraredTask.setBizId(toInfrared.getId());
 
-        HttpResponseEntity responseEntity = sendPostMsg(to_infrared_stop_url, startToInfraredTask);
+        HttpResponseEntity responseEntity = sendPostMsg(task_stop_url, startToInfraredTask);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
             toInfrared.setStatus(BizConstant.VideoStatus.INTERRUPTED);
             updateById(toInfrared);

+ 3 - 3
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -54,8 +54,8 @@ import static com.taais.biz.service.impl.VideoStableServiceImpl.makeDir;
  */
 @Service
 public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMapper, TrackSequence> implements ITrackSequenceService {
-    @Value("${server.track_sequence_stop_url}")
-    private String track_sequence_stop_url;
+    @Value("${server.task_stop_url}")
+    private String task_stop_url;
 
     @Autowired
     private AlgorithmConfigTrackServiceImpl algorithmConfigTrackService;
@@ -284,7 +284,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
     public CommonResult stop(Long id) {
         TrackSequence trackSequence = getById(id);
 
-        HttpResponseEntity responseEntity = sendPostMsg(track_sequence_stop_url, trackSequence);
+        HttpResponseEntity responseEntity = sendPostMsg(task_stop_url, trackSequence);
         if (ObjectUtil.isNotNull(responseEntity) && responseEntity.getStatus() == 200) {
             trackSequence.setStatus(BizConstant.VideoStatus.INTERRUPTED);
             updateById(trackSequence);