WANGKANG 8 сар өмнө
parent
commit
a25d7b5178

+ 38 - 24
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -33,6 +33,7 @@ import com.taais.common.orm.core.service.impl.BaseServiceImpl;
 import com.taais.system.domain.vo.SysOssVo;
 import com.taais.system.service.ISysOssService;
 import jakarta.annotation.Resource;
+import org.springframework.beans.BeanUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.core.io.FileSystemResource;
@@ -144,13 +145,21 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         Page<TrackSequenceVo> page = this.pageAs(PageQuery.build(), queryWrapper, TrackSequenceVo.class);
         page.getRecords().forEach(entity -> {
             Long modelId = entity.getAlgorithmModelId();
-            AlgorithmModelTrackVo model = algorithmModelTrackService.selectById(modelId);
-            if (ObjectUtil.isNotNull(model)) {
-                AlgorithmConfigTrackVo config = algorithmConfigTrackService.selectById(model.getAlgorithmId());
-                entity.setType(config.getType());
-                entity.setSubsystem(config.getSubsystem());
-                entity.setAlgorithmName(config.getAlgorithmName());
-                entity.setModelName(model.getModelName());
+            Long algorithmId = entity.getAlgorithmId();
+            if (modelId != null) {
+                AlgorithmModelTrackVo model = algorithmModelTrackService.selectById(modelId);
+                if (ObjectUtil.isNotNull(model)) {
+                    entity.setModelName(model.getModelName());
+                }
+
+            }
+            if (algorithmId != null) {
+                AlgorithmConfigTrackVo config = algorithmConfigTrackService.selectById(algorithmId);
+                if (ObjectUtil.isNotNull(config)) {
+                    entity.setType(config.getType());
+                    entity.setSubsystem(config.getSubsystem());
+                    entity.setAlgorithmName(config.getAlgorithmName());
+                }
             }
         });
         return PageResult.build(page);
@@ -165,23 +174,33 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
     @Override
     public CommonResult insert(TrackSequenceBo trackSequenceBo) {
         // 检查任务名称
-        if(ObjectUtil.isEmpty(trackSequenceBo.getName()) || (!trackSequenceBo.getName().startsWith("MASC") && !trackSequenceBo.getName().startsWith("CAT"))){
+        if (ObjectUtil.isEmpty(trackSequenceBo.getName()) || (!trackSequenceBo.getName().startsWith("MASC") && !trackSequenceBo.getName().startsWith("CAT"))) {
             return CommonResult.fail("任务命名错误,需以MASC或CAT开头!");
         }
 
         // 检查input_oss_id是否存在
         if (ObjectUtil.isNull(trackSequenceBo.getInputOssId())) {
-            return  CommonResult.fail("请上传模型");
+            return CommonResult.fail("请上传模型");
         }
 
         SysOssVo ossEntity = ossService.getById(trackSequenceBo.getInputOssId());
         if (ObjectUtil.isNull(ossEntity)) {
-            return  CommonResult.fail("找不到指定模型!");
+            return CommonResult.fail("找不到指定模型!");
+        }
+
+        if (ObjectUtil.isEmpty(trackSequenceBo.getAlgorithmId())) {
+            return CommonResult.fail("请指定算法!");
+        }
+
+        AlgorithmConfigTrackVo config = algorithmConfigTrackService.selectById(trackSequenceBo.getAlgorithmId());
+        if (ObjectUtil.isNull(config)) {
+            return CommonResult.fail("找不到指定的算法!");
         }
 
         TrackSequence trackSequence = new TrackSequence();
 
-        trackSequence.setInputOssId(trackSequenceBo.getInputOssId());
+        BeanUtils.copyProperties(trackSequenceBo, trackSequence);
+
         trackSequence.setUrl(ossEntity.getUrl());
 
         String filePath = ossEntity.getFileName();
@@ -197,18 +216,13 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         trackSequence.setOutputPath(outPath.toString());
 
         trackSequence.setZipFilePath(path.resolveSibling(fileName_without_suffix + ".zip").toString());
-
-        trackSequence.setName(trackSequenceBo.getName());
         trackSequence.setStatus(NOT_START);
-        trackSequence.setRemarks(trackSequenceBo.getRemarks());
 
-        trackSequence.setAlgorithmModelId(trackSequenceBo.getAlgorithmModelId());
 
-        boolean __  = this.save(trackSequence);// 使用全局配置的雪花算法主键生成器生成ID值
-        if(__) {
+        boolean __ = this.save(trackSequence);// 使用全局配置的雪花算法主键生成器生成ID值
+        if (__) {
             return CommonResult.success();
-        }
-        else {
+        } else {
             return CommonResult.fail();
         }
     }
@@ -285,15 +299,13 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         startTaskConfig.setLog_path(entity.getOutputPath() + File.separator + getLogFileName(entity));
 
         if (BizConstant.AlgorithmType.REASONING.equals(algorithmConfigTrack.getType())) {
-            if(algorithmModelTrack.getModelName().startsWith("masc") || algorithmModelTrack.getModelName().startsWith("MASC")) {
+            if (algorithmModelTrack.getModelName().startsWith("masc") || algorithmModelTrack.getModelName().startsWith("MASC")) {
                 String modelPath = algorithmModelTrack.getModelAddress() + File.separator + algorithmModelTrack.getModelName().substring(5);
                 startTaskConfig.setModel_path(modelPath);
-            }
-            else if(algorithmModelTrack.getModelName().startsWith("cat") || algorithmModelTrack.getModelName().startsWith("CAT")) {
+            } else if (algorithmModelTrack.getModelName().startsWith("cat") || algorithmModelTrack.getModelName().startsWith("CAT")) {
                 String modelPath = algorithmModelTrack.getModelAddress();
                 startTaskConfig.setModel_path(modelPath);
-            }
-            else {
+            } else {
                 return CommonResult.fail("模型命名失败,请以MASC或CAT开头命名模型");
             }
         }
@@ -407,9 +419,11 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         }
         return CommonResult.success();
     }
+
     private String getLogFileName(TrackSequence entity) {
         return entity.getId() + BizConstant.TRACK_SEQUENCE_SUFFIX + ".log";
     }
+
     @Override
     public CommonResult getLog(Long id) {
         TrackSequence entity = getById(id);