浏览代码

Merge branch 'dev_lsk_1103' of www/taais into develop

Sk18834839360 7 月之前
父节点
当前提交
72737b9279

+ 2 - 2
taais-modules/taais-biz/src/main/java/com/taais/biz/constant/BizConstant.java

@@ -120,9 +120,9 @@ public class BizConstant {
 
     public static final String TYPE_OBJ_MATCH = "OBJ_MATCH";
 
-    public static final String DOCKER_OBJ_MATCH_PATH = DOCKER_BASE_PATH + "/objectMatch";
+    public static final String DOCKER_OBJ_MATCH_PATH = "/objectMatch";
     public static final String TYPE_OBJ_TRACE = "OBJ_TRACE";
-    public static final String DOCKER_MAT_TASK = DOCKER_BASE_PATH + "/obj_track";
+    public static final String DOCKER_MAT_TASK = "/obj_track";
     public static final String MULTI_OBJ_TRACE_URL = "127.0.0.1:10027/objTrace";
     public static final String MULTI_OBJ_MATCH_URL = "127.0.0.1:10028/imgMatch";
 }

+ 6 - 6
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/ObjectMatchController.java

@@ -146,9 +146,9 @@ public class ObjectMatchController extends BaseController {
 
         params.put("bizId", String.valueOf(bo.getId()));
         params.put("bizType", TYPE_OBJ_MATCH);
-        params.put("logPath", bo.getResultPath() + "/log.log");
-        params.put("sourcePath", bo.getPreprocessPath());
-        params.put("resultPath", bo.getResultPath());
+        params.put("logPath", DOCKER_BASE_PATH + bo.getResultPath() + "/log.log");
+        params.put("sourcePath", DOCKER_BASE_PATH +  bo.getPreprocessPath());
+        params.put("resultPath", DOCKER_BASE_PATH + bo.getResultPath());
         params.put("otherParams", new JSONObject().toString());
 
         log.info("obj_match params: {}", params);
@@ -175,7 +175,7 @@ public class ObjectMatchController extends BaseController {
             return CommonResult.fail("未找到任务", null);
         }
         String path = bo.getResultPath();
-        File dir = new File(path + "/IR_VIS_obj_in_IR");
+        File dir = new File(BizConstant.DOCKER_BASE_PATH + path + "/IR_VIS_obj_in_IR");
         List<String> res = new ArrayList<>();
         if (dir.exists()) {
             for (File file : dir.listFiles()) {
@@ -201,10 +201,10 @@ public class ObjectMatchController extends BaseController {
 
         SysOssVo _file = sysOssService.getById(Long.parseLong(file));
         String filePath = TaaisConfig.getProfile() + _file.getUrl().split("/profile")[1];
-        ZipFileExtractor.extractZipFile(filePath, path);
+        ZipFileExtractor.extractZipFile(filePath, DOCKER_BASE_PATH + path);
 
         match.setResultPath(path + "/result");
-        File dir = new File(match.getResultPath());
+        File dir = new File(DOCKER_BASE_PATH + match.getResultPath());
         if (!dir.exists()) {
             dir.mkdirs();
         }

+ 4 - 4
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/ObjectTraceMergeController.java

@@ -106,7 +106,7 @@ public class ObjectTraceMergeController extends BaseController {
         }
         // 创建结果路径
         try {
-            File dir = new File(resultPath);
+            File dir = new File(DOCKER_BASE_PATH + resultPath);
             if (!dir.exists()) {
                 dir.mkdirs();
             }
@@ -138,9 +138,9 @@ public class ObjectTraceMergeController extends BaseController {
         Map<String, String> params = new HashMap<>();
         params.put("bizType", TYPE_OBJ_TRACE);
         params.put("bizId", String.valueOf(vo.getId()));
-        params.put("logPath", vo.getResultPath());
+        params.put("logPath", DOCKER_BASE_PATH + vo.getResultPath());
         params.put("sourcePath", vo.getPreprocessPath());
-        params.put("resultPath", vo.getResultPath());
+        params.put("resultPath", DOCKER_BASE_PATH +  vo.getResultPath());
         params.put("otherParams", new JSONObject().toString());
 
         log.info("obj_trace params: {}", params);
@@ -171,7 +171,7 @@ public class ObjectTraceMergeController extends BaseController {
 //            return CommonResult.fail("任务未成功执行!", new ArrayList<>());
 //        }
         try {
-            String res = bo.getResultPath();
+            String res = DOCKER_BASE_PATH + bo.getResultPath();
             File dir = new File(res);
             File[] files = dir.listFiles();
             List<String> resList = new ArrayList<>();

+ 13 - 8
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationSubtaskServiceImpl.java

@@ -239,23 +239,28 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
             log.error(e.getMessage());
         }
 
+        if (hasModelProperty) {
+            AlgorithmModelVo bo = algorithmModelService.selectById(Long.valueOf(_modelId));
+            String path = bo.getModelAddress().replace("/profile", BizConstant.DOCKER_BASE_PATH);
+            algorithmRequestDto.getOtherParams().put("pretrained", true);
+            algorithmRequestDto.getOtherParams().put("pretrained_model", path);
+        }
+
         String taskName = detail.getName();
         if (taskName.contains("训练")) {
             log.info("train");
-            if (hasModelProperty) {
-                AlgorithmModelVo bo = algorithmModelService.selectById(Long.valueOf(_modelId));
-                String path = bo.getModelAddress().replace("/profile", BizConstant.DOCKER_BASE_PATH);
-                algorithmRequestDto.getOtherParams().put("pretrained", true);
-                algorithmRequestDto.getOtherParams().put("pretrained_model", path);
-            }
         } else if (taskName.contains("验证")) {
             String[] urls = url.split(";;;");
             url = urls[0];
-            algorithmRequestDto.getOtherParams().put("weight_path", BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + urls[1] + "/result/weights/best.pt");
+            if (urls.length > 1) {
+                algorithmRequestDto.getOtherParams().put("weight_path", BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + urls[1] + "/result/weights/best.pt");
+            }
         } else if (taskName.contains("测试")) {
             String[] urls = url.split(";;;");
             url = urls[0];
-            algorithmRequestDto.getOtherParams().put("weight_path", BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + urls[1] + "/result/weights/best.pt");
+            if (urls.length > 1) {
+                algorithmRequestDto.getOtherParams().put("weight_path", BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + urls[1] + "/result/weights/best.pt");
+            }
         } else {
             log.error("taskName error: " + taskName);
             return;

+ 24 - 4
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationTaskServiceImpl.java

@@ -451,6 +451,24 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
         return false;
     }
 
+    /**
+     * 检查是否有test集
+     * @param batch
+     */
+    private boolean hasTestSet(String batch) {
+        String[] batches = batch.split(",");
+
+        for (String batchNum : batches) {
+            List<DataVo> dataVoList = dataService.getDataByBatchNum(batchNum);
+            for (DataVo dataVo : dataVoList) {
+                if (StringUtils.isEmpty(dataVo.getLabelurl())) {
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
+
     private void createTestTask(Long taskId, CreateTargetIdentificationTaskDto taskDto, Map<String, String> records) {
         List<TaskDto> algTaskList = taskDto.getAlgTaskList();
         List<String> testBatchNumList = taskDto.getTestBatchNumList();
@@ -497,7 +515,7 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
                     subtaskDetail.setName(algName + "_验证");
                     subtaskDetail.setParameters(params.get(1));
                     copyFilesToPath(batchNum, subtaskPath, true);
-                    subtaskDetail.setType(algorithmModelVo.getVerifyUrl() + ";;;" + records.get(algName));
+                    subtaskDetail.setType(algorithmModelVo.getVerifyUrl() + ";;;" + (records != null ? records.get(algName) : ""));
                     subtaskDetailsService.insert(subtaskDetail);
                     // reset to '测试'
                     subtaskPath = "/" + UUID.randomUUID().toString().replace("-", "_");
@@ -508,9 +526,11 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
 
                 subtaskDetail.setName(algName + "_测试");
                 subtaskDetail.setPreprocessPath(subtaskDetail.getPreprocessPath() + "/images");
-                subtaskDetail.setType(algorithmModelVo.getTestUrl() + ";;;" + records.get(algName));
-                copyFilesToPath(batchNum, subtaskPath, false);
-                subtaskDetailsService.insert(subtaskDetail);
+                subtaskDetail.setType(algorithmModelVo.getTestUrl() + ";;;" + (records != null ? records.get(algName) : ""));
+                if (hasTestSet(batchNum)) {
+                    copyFilesToPath(batchNum, subtaskPath, false);
+                    subtaskDetailsService.insert(subtaskDetail);
+                }
             }
         }
     }