Przeglądaj źródła

feat: CAT训练、评估、预测接口对接完成

WANGKANG 8 miesięcy temu
rodzic
commit
be71cf832c

+ 20 - 10
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -270,8 +270,17 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         startToInfraredTask.setResult_dir(entity.getOutputPath());
 
         if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
-            String modelPath = algorithmModelTrack.getModelAddress() + File.separator + algorithmModelTrack.getModelName();
-            startToInfraredTask.setModel_path(modelPath);
+            if(algorithmModelTrack.getModelName().startsWith("masc") || algorithmModelTrack.getModelName().startsWith("MASC")) {
+                String modelPath = algorithmModelTrack.getModelAddress() + File.separator + algorithmModelTrack.getModelName();
+                startToInfraredTask.setModel_path(modelPath);
+            }
+            else if(algorithmModelTrack.getModelName().startsWith("cat") || algorithmModelTrack.getModelName().startsWith("CAT")) {
+                String modelPath = algorithmModelTrack.getModelAddress();
+                startToInfraredTask.setModel_path(modelPath);
+            }
+            else {
+                return CommonResult.fail("模型命名失败,请以MASC或CAT开头命名模型");
+            }
         }
 
 
@@ -318,15 +327,16 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
         AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
 
-        if(BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())){
+        if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
             outputPath = entity.getOutputPath() + File.separator + "predict";
             zipFilePath = outputPath + ".zip";
-        }
-        else if(BizConstant.AlgorithmType.TEST.equals(algorithmConfigTrack.getType())){
+        } else if (BizConstant.AlgorithmType.TEST.equals(algorithmConfigTrack.getType())) {
             outputPath = entity.getOutputPath() + File.separator + "evaluate";
             zipFilePath = outputPath + ".zip";
-        }
-        else {
+        } else if (BizConstant.AlgorithmType.TRAIN.equals(algorithmConfigTrack.getType())) {
+            outputPath = entity.getOutputPath();
+            zipFilePath = outputPath + ".zip";
+        } else {
             System.out.println("未知算法类型!");
             return ResponseEntity.status(HttpStatus.NOT_FOUND).body(null);
         }
@@ -371,11 +381,11 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         algorithmModelTrack.setModelStatus("200".equals(status) ? BizConstant.ModelStatus.END : BizConstant.ModelStatus.FAILED);
 
         AlgorithmConfigTrack algorithmConfigTrack = algorithmConfigTrackService.getById(algorithmModelTrack.getAlgorithmId());
-        String params =  algorithmConfigTrack.getParameterConfig();
+        String params = algorithmConfigTrack.getParameterConfig();
         HashMap<String, Object> parse = (HashMap<String, Object>) JSON.parse((params));
 
-        if("200".equals(status) && ObjectUtil.isNull(algorithmModelTrack.getModelAddress())) {
-            algorithmModelTrack.setModelAddress(entity.getOutputPath() + File.separator + ((HashMap<String, String>)parse.get("dataset")).get("classes"));
+        if ("200".equals(status) && ObjectUtil.isNull(algorithmModelTrack.getModelAddress())) {
+            algorithmModelTrack.setModelAddress(entity.getOutputPath() + File.separator + ((HashMap<String, String>) parse.get("dataset")).get("classes"));
 
             System.out.println(parse.get("dataset"));
             algorithmModelTrackService.updateById(algorithmModelTrack);