Bladeren bron

feat: MASC预测对接

WANGKANG 9 maanden geleden
bovenliggende
commit
6ea045fb2f

+ 37 - 47
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -50,6 +50,7 @@ import com.taais.biz.service.ITrackSequenceService;
 import static com.taais.biz.constant.BizConstant.VideoStatus.NOT_START;
 import static com.taais.biz.domain.table.AlgorithmModelTrackTableDef.ALGORITHM_MODEL_TRACK;
 import static com.taais.biz.domain.table.TrackSequenceTableDef.TRACK_SEQUENCE;
+import static com.taais.biz.service.impl.ToInfraredServiceImpl.readLogContent;
 import static com.taais.biz.service.impl.VideoStableServiceImpl.*;
 import static com.taais.biz.service.impl.VideoStableServiceImpl.makeDir;
 
@@ -82,30 +83,18 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
     private QueryWrapper buildQueryWrapper(TrackSequenceBo trackSequenceBo) {
         QueryWrapper queryWrapper = super.buildBaseQueryWrapper();
-        queryWrapper.and(TRACK_SEQUENCE.NAME.like
-            (trackSequenceBo.getName()));
-        queryWrapper.and(TRACK_SEQUENCE.STATUS.eq
-            (trackSequenceBo.getStatus()));
-        queryWrapper.and(TRACK_SEQUENCE.START_TIME.eq
-            (trackSequenceBo.getStartTime()));
-        queryWrapper.and(TRACK_SEQUENCE.END_TIME.eq
-            (trackSequenceBo.getEndTime()));
-        queryWrapper.and(TRACK_SEQUENCE.COST_SECOND.eq
-            (trackSequenceBo.getCostSecond()));
-        queryWrapper.and(TRACK_SEQUENCE.LOG.eq
-            (trackSequenceBo.getLog()));
-        queryWrapper.and(TRACK_SEQUENCE.REMARKS.eq
-            (trackSequenceBo.getRemarks()));
-        queryWrapper.and(TRACK_SEQUENCE.URL.eq
-            (trackSequenceBo.getUrl()));
-        queryWrapper.and(TRACK_SEQUENCE.INPUT_OSS_ID.eq
-            (trackSequenceBo.getInputOssId()));
-        queryWrapper.and(TRACK_SEQUENCE.INPUT_PATH.eq
-            (trackSequenceBo.getInputPath()));
-        queryWrapper.and(TRACK_SEQUENCE.OUTPUT_PATH.eq
-            (trackSequenceBo.getOutputPath()));
-        queryWrapper.and(TRACK_SEQUENCE.ZIP_FILE_PATH.eq
-            (trackSequenceBo.getZipFilePath()));
+        queryWrapper.and(TRACK_SEQUENCE.NAME.like(trackSequenceBo.getName()));
+        queryWrapper.and(TRACK_SEQUENCE.STATUS.eq(trackSequenceBo.getStatus()));
+        queryWrapper.and(TRACK_SEQUENCE.START_TIME.eq(trackSequenceBo.getStartTime()));
+        queryWrapper.and(TRACK_SEQUENCE.END_TIME.eq(trackSequenceBo.getEndTime()));
+        queryWrapper.and(TRACK_SEQUENCE.COST_SECOND.eq(trackSequenceBo.getCostSecond()));
+        queryWrapper.and(TRACK_SEQUENCE.LOG.eq(trackSequenceBo.getLog()));
+        queryWrapper.and(TRACK_SEQUENCE.REMARKS.eq(trackSequenceBo.getRemarks()));
+        queryWrapper.and(TRACK_SEQUENCE.URL.eq(trackSequenceBo.getUrl()));
+        queryWrapper.and(TRACK_SEQUENCE.INPUT_OSS_ID.eq(trackSequenceBo.getInputOssId()));
+        queryWrapper.and(TRACK_SEQUENCE.INPUT_PATH.eq(trackSequenceBo.getInputPath()));
+        queryWrapper.and(TRACK_SEQUENCE.OUTPUT_PATH.eq(trackSequenceBo.getOutputPath()));
+        queryWrapper.and(TRACK_SEQUENCE.ZIP_FILE_PATH.eq(trackSequenceBo.getZipFilePath()));
 
         return queryWrapper;
     }
@@ -271,7 +260,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
         Path path = Paths.get(resourcePath);
         Path inputPath = path.resolveSibling(fileName_without_suffix + BizConstant.UNZIP_SUFFIX);
-        Path outputPath = path.resolveSibling(fileName_without_suffix + BizConstant.TRACK_SEQUENCE_SUFFIX);
+        Path outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX);
 
 //        makeDir(inputPath.toString());
         makeDir(outputPath.toString());
@@ -287,7 +276,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         entity.setStartTime(new Date());
 
         AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
-        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
+        AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(entity.getAlgorithmId());
 
         StartTaskConfig startTaskConfig = new StartTaskConfig();
         startTaskConfig.setBizType(BizConstant.BizType.TRACK_SEQUENCE);
@@ -300,15 +289,17 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         startTaskConfig.setLog_path(entity.getOutputPath() + File.separator + getLogFileName(entity));
 
         if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
-            if (algorithmModelTrack.getModelName().startsWith("masc") || algorithmModelTrack.getModelName().startsWith("MASC")) {
-                String modelPath = algorithmModelTrack.getModelAddress() + File.separator + algorithmModelTrack.getModelName().substring(5);
-                startTaskConfig.setModel_path(modelPath);
-            } else if (algorithmModelTrack.getModelName().startsWith("cat") || algorithmModelTrack.getModelName().startsWith("CAT")) {
-                String modelPath = algorithmModelTrack.getModelAddress();
-                startTaskConfig.setModel_path(modelPath);
-            } else {
-                return CommonResult.fail("模型命名失败,请以MASC或CAT开头命名模型");
-            }
+            String modelPath = algorithmModelTrack.getModelAddress();
+            startTaskConfig.setModel_path(modelPath);
+//            if (algorithmModelTrack.getModelName().startsWith("masc") || algorithmModelTrack.getModelName().startsWith("MASC")) {
+//                String modelPath = algorithmModelTrack.getModelAddress() + File.separator + algorithmModelTrack.getModelName().substring(5);
+//                startTaskConfig.setModel_path(modelPath);
+//            } else if (algorithmModelTrack.getModelName().startsWith("cat") || algorithmModelTrack.getModelName().startsWith("CAT")) {
+//                String modelPath = algorithmModelTrack.getModelAddress();
+//                startTaskConfig.setModel_path(modelPath);
+//            } else {
+//                return CommonResult.fail("模型命名失败,请以MASC或CAT开头命名模型");
+//            }
         }
 
 
@@ -357,9 +348,17 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
         if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
             outputPath = entity.getOutputPath() + File.separator + "predict";
+            File file__ = new File(outputPath);
+            if (!file__.exists()) {
+                outputPath = entity.getOutputPath();
+            }
             zipFilePath = outputPath + ".zip";
         } else if (BizConstant.AlgorithmType.TEST.equals(algorithmConfigTrack.getType())) {
             outputPath = entity.getOutputPath() + File.separator + "evaluate";
+            File file__ = new File(outputPath);
+            if (!file__.exists()) {
+                outputPath = entity.getOutputPath();
+            }
             zipFilePath = outputPath + ".zip";
         } else if (BizConstant.AlgorithmType.TRAIN.equals(algorithmConfigTrack.getType())) {
             outputPath = entity.getOutputPath();
@@ -383,10 +382,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         }
 
         org.springframework.core.io.Resource resource = new FileSystemResource(file);
-        return ResponseEntity.ok()
-            .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + file.getName() + "\"")
-            .header(HttpHeaders.CONTENT_TYPE, "application/octet-stream")
-            .body(resource);
+        return ResponseEntity.ok().header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + file.getName() + "\"").header(HttpHeaders.CONTENT_TYPE, "application/octet-stream").body(resource);
     }
 
     @Override
@@ -430,18 +426,12 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         TrackSequence entity = getById(id);
         String outPutPath = entity.getOutputPath();
         String logPath = outPutPath + File.separator + getLogFileName(entity);
+        System.out.println(logPath);
         File file = new File(logPath);
         if (!file.exists()) {
             return CommonResult.fail("日志文件不存在!");
         }
-        try (BufferedReader br = new BufferedReader(new FileReader(logPath))) {
-            String log = Files.readString(Paths.get(logPath));
-            log = log.replaceAll("\n", "<br/>\n");
-            return CommonResult.success(log, "success");
-        } catch (Exception e) {
-            e.printStackTrace();
-            return CommonResult.fail("读取日志失败!");
-        }
+        return CommonResult.success(readLogContent(logPath), "success");
     }
 
     @Override