Jelajahi Sumber

feat: CAT模型预测对接完成

WANGKANG 9 bulan lalu
induk
melakukan
6b3a78250e

+ 40 - 6
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -153,8 +153,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
                     entity.setSubsystem(config.getSubsystem());
                     entity.setAlgorithmName(config.getAlgorithmName());
                 }
-            }
-            else {
+            } else {
                 AlgorithmModelTrackVo model = algorithmModelTrackService.selectById(modelId);
                 if (ObjectUtil.isNotNull(model)) {
                     AlgorithmConfigTrackVo config = algorithmConfigTrackService.selectById(model.getAlgorithmId());
@@ -263,6 +262,11 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 
     @Override
     public CommonResult start(Long id) {
+        /*
+         * WANGKANG 望维护此代码的后来者安息。
+         *
+         * 不是我想写这么恶心,只是没办法,算法端的逻辑写的跟屎一样。。。。
+         */
         TrackSequence entity = getById(id);
         AlgorithmModelTrack algorithmModelTrack = algorithmModelTrackService.getById(entity.getAlgorithmModelId());
         AlgorithmConfigTrack algorithmConfigTrack = null;
@@ -322,6 +326,33 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
             if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
                 String modelPath = algorithmModelTrack.getModelAddress();
                 startTaskConfig.setModel_path(modelPath);
+                if (entity.getName().startsWith(CAT)) {
+                    File file________ = new File(startTaskConfig.getSource_dir());
+                    if (!file________.exists()) {
+                        return CommonResult.fail("数据集为空!");
+                    }
+                    if (file________.listFiles() != null && file________.listFiles().length > 0) {
+                        for (File file___________________tmp : file________.listFiles()) {
+                            if (file___________________tmp.isDirectory()) {
+                                startTaskConfig.setSource_dir(file___________________tmp.getPath());
+                                break;
+                            }
+                        }
+                    }
+
+                    File file___________________________ = new File(startTaskConfig.getModel_path());
+                    if (!file___________________________.exists()) {
+                        return CommonResult.fail("模型不存在!");
+                    }
+                    if (file___________________________.listFiles() != null && file___________________________.listFiles().length > 0) {
+                        for (File file___________________tmp : file___________________________.listFiles()) {
+                            if (file___________________tmp.isDirectory()) {
+                                startTaskConfig.setModel_path(file___________________tmp.getPath());
+                                break;
+                            }
+                        }
+                    }
+                }
 //            if (algorithmModelTrack.getModelName().startsWith("masc") || algorithmModelTrack.getModelName().startsWith("MASC")) {
 //                String modelPath = algorithmModelTrack.getModelAddress() + File.separator + algorithmModelTrack.getModelName().substring(5);
 //                startTaskConfig.setModel_path(modelPath);
@@ -331,14 +362,17 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
 //            } else {
 //                return CommonResult.fail("模型命名失败,请以MASC或CAT开头命名模型");
 //            }
-            }
-            else if(BizConstant.AlgorithmType.TRAIN.equals(algorithmConfigTrack.getType()) && entity.getName().startsWith(CAT)) {
+            } else if (BizConstant.AlgorithmType.TRAIN.equals(algorithmConfigTrack.getType()) && entity.getName().startsWith(CAT)) {
                 File file________ = new File(startTaskConfig.getSource_dir());
                 if (!file________.exists()) {
                     return CommonResult.fail("数据集为空!");
                 }
-                if(file________.listFiles()!=null && file________.listFiles().length>0) {
-                    startTaskConfig.setSource_dir(startTaskConfig.getSource_dir() + File.separator + file________.listFiles()[0].getName());
+                if (file________.listFiles() != null && file________.listFiles().length > 0) {
+                    for (File file___________________tmp : file________.listFiles()) {
+                        if (file___________________tmp.isDirectory()) {
+                            startTaskConfig.setSource_dir(file___________________tmp.getPath());
+                        }
+                    }
                 }
             }
         }