浏览代码

[FEATURE] 执行目标识别算法任务

Suuuuuukang 9 月之前
父节点
当前提交
dda598da11

+ 14 - 10
taais-modules/taais-biz/src/main/java/com/taais/biz/component/ScheduledTasks.java

@@ -19,17 +19,21 @@ public class ScheduledTasks {
     @Resource
     TargetIdentificationTaskServiceImpl targetIdentificationTaskService;
 
-    @Scheduled(fixedRate = 10000)
-    public void runTask() {
-        log.info("ScheduledTasks.runTask start");
-        algorithmTaskService.taskRun();
-        log.info("ScheduledTasks.runTask end");
-    }
+    //@Scheduled(fixedRate = 10000)
+    //public void runTask() {
+    //    log.info("ScheduledTasks.runTask start");
+    //    algorithmTaskService.taskRun();
+    //    log.info("ScheduledTasks.runTask end");
+    //}
 
-    @Scheduled(fixedRate = 10000)
+    @Scheduled(fixedRate = 30000)
     public void taskRun() {
-        log.info("ScheduledTasks.taskRun start");
-        targetIdentificationTaskService.taskRun();
-        log.info("ScheduledTasks.taskRun end");
+        try {
+            log.info("ScheduledTasks.taskRun start");
+            targetIdentificationTaskService.taskRun();
+            log.info("ScheduledTasks.taskRun end");
+        } catch (Exception e) {
+            log.error("ScheduledTasks.taskRun error", e);
+        }
     }
 }

+ 1 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/constant/BizConstant.java

@@ -69,6 +69,6 @@ public class BizConstant {
     public static final String P_CURVE = "P_curve.png";
     public static final String F1_CURVE = "F1_curve.png";
     public static final String ORIGINAL_IMAGE= "原始图片";
-    public static final String DOCKER_BASE_PATH= "/workspace";
+    public static final String DOCKER_BASE_PATH= "/home/ObjectDetection_Web";
     public static final String DOCKER_PT_PATH= "weights/best.pt";
 }

+ 4 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/PublicController.java

@@ -34,6 +34,10 @@ public class PublicController extends BaseController {
 
     @Resource
     private IAlgorithmTaskService algorithmTaskService;
+
+
+
+    // todo: 2024080906
     @PostMapping("/taskResult")
     public CommonResult<Void> taskResult(@RequestBody TaskResultDTO resultDTO) {
         log.info("taskResult start,params:{}",resultDTO);

+ 16 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/TargetIdentificationSubtaskDetailsController.java

@@ -1,7 +1,9 @@
 package com.taais.biz.controller;
 
 import java.util.List;
+import java.util.Map;
 
+import com.taais.biz.service.impl.TargetIdentificationSubtaskServiceImpl;
 import lombok.RequiredArgsConstructor;
 import jakarta.servlet.http.HttpServletResponse;
 import cn.dev33.satoken.annotation.SaCheckPermission;
@@ -34,6 +36,9 @@ public class TargetIdentificationSubtaskDetailsController extends BaseController
     @Resource
     private ITargetIdentificationSubtaskDetailsService targetIdentificationSubtaskDetailsService;
 
+    @Resource
+    private TargetIdentificationSubtaskServiceImpl subtaskService;
+
     /**
      * 查询目标识别子任务列表
      */
@@ -106,4 +111,15 @@ public class TargetIdentificationSubtaskDetailsController extends BaseController
         }
         return CommonResult.success();
     }
+
+    @PostMapping("/execute")
+    public CommonResult<Void> add(@RequestBody Map<String, String> params) {
+        try {
+            Long taskId = Long.parseLong(params.get("taskId"));
+            subtaskService.executeOneTask(taskId);
+            return CommonResult.success();
+        } catch (Exception e) {
+            return CommonResult.fail(e.getMessage());
+        }
+    }
 }

+ 2 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/TargetIdentificationSubtaskDetails.java

@@ -70,6 +70,8 @@ private static final long serialVersionUID = 1L;
     /** $column.columnComment */
     private String log;
 
+
+
     /** 备注 */
     private String remarks;
 

+ 2 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/vo/TargetIdentificationSubtaskDetailsVo.java

@@ -48,6 +48,8 @@ private static final long serialVersionUID = 1L;
     @ExcelProperty(value = "算法")
     private Long algorithmId;
 
+    private String type;
+
     /** 数据批次号 */
     @ExcelProperty(value = "数据批次号")
     private String dataBatchNums;

+ 2 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/mapper/TargetIdentificationSubtaskDetailsMapper.java

@@ -2,7 +2,6 @@ package com.taais.biz.mapper;
 
 import com.mybatisflex.core.BaseMapper;
 import com.taais.biz.domain.bo.TargetIdentificationSubtaskDetailsBo;
-import com.taais.biz.domain.vo.TargetIdentificationSubtaskDetailsVo;
 import org.apache.ibatis.annotations.Mapper;
 import com.taais.biz.domain.TargetIdentificationSubtaskDetails;
 
@@ -18,4 +17,6 @@ import java.util.List;
 public interface TargetIdentificationSubtaskDetailsMapper extends BaseMapper<TargetIdentificationSubtaskDetails> {
 
     List<TargetIdentificationSubtaskDetailsBo> getTargetIdentificationSubtaskDetailsListBySubtaskId(Long subtaskId);
+
+    List<Long> getByTaskType(String taskType);
 }

+ 23 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationSubtaskDetailsServiceImpl.java

@@ -1,5 +1,6 @@
 package com.taais.biz.service.impl;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
@@ -87,6 +88,10 @@ public class TargetIdentificationSubtaskDetailsServiceImpl extends BaseServiceIm
 
     }
 
+    public TargetIdentificationSubtaskDetailsBo getById(Long id) {
+        return this.getOneAs(query().where(TARGET_IDENTIFICATION_SUBTASK_DETAILS.ID.eq(id)), TargetIdentificationSubtaskDetailsBo.class);
+    }
+
     /**
      * 查询目标识别子任务列表
      *
@@ -123,6 +128,7 @@ public class TargetIdentificationSubtaskDetailsServiceImpl extends BaseServiceIm
         TargetIdentificationSubtaskDetails targetIdentificationSubtaskDetails =MapstructUtils.convert(targetIdentificationSubtaskDetailsBo, TargetIdentificationSubtaskDetails. class);
         targetIdentificationSubtaskDetails.setSubtaskId(targetIdentificationSubtaskDetailsBo.getSubtaskId());
         targetIdentificationSubtaskDetails.setType(targetIdentificationSubtaskDetailsBo.getType());
+        targetIdentificationSubtaskDetails.setIndex(targetIdentificationSubtaskDetailsBo.getIndex());
         System.out.println("convert subdetail: " + targetIdentificationSubtaskDetails);
         System.out.println("convert before subdetail: " + targetIdentificationSubtaskDetailsBo);
         return this.save(targetIdentificationSubtaskDetails);//使用全局配置的雪花算法主键生成器生成ID值
@@ -137,6 +143,7 @@ public class TargetIdentificationSubtaskDetailsServiceImpl extends BaseServiceIm
     @Override
     public boolean update(TargetIdentificationSubtaskDetailsBo targetIdentificationSubtaskDetailsBo) {
         TargetIdentificationSubtaskDetails targetIdentificationSubtaskDetails =MapstructUtils.convert(targetIdentificationSubtaskDetailsBo, TargetIdentificationSubtaskDetails. class);
+        System.out.println("convert subdetail: " + targetIdentificationSubtaskDetails);
         if (ObjectUtil.isNotNull(targetIdentificationSubtaskDetails) && ObjectUtil.isNotNull(targetIdentificationSubtaskDetails.getId())){
             boolean updated = this.updateById(targetIdentificationSubtaskDetails);
                 return updated;
@@ -156,4 +163,20 @@ public class TargetIdentificationSubtaskDetailsServiceImpl extends BaseServiceIm
         return this.removeByIds(Arrays.asList(ids));
     }
 
+    public Long getAvailableTask() {
+        List<Long> list = null;
+        list = targetIdentificationSubtaskDetailsMapper.getByTaskType("训练");
+        if (!list.isEmpty()) {
+            return list.get(0);
+        }
+        list = targetIdentificationSubtaskDetailsMapper.getByTaskType("验证");
+        if (!list.isEmpty()) {
+            return list.get(0);
+        }
+        list = targetIdentificationSubtaskDetailsMapper.getByTaskType("测试");
+        if (!list.isEmpty()) {
+            return list.get(0);
+        }
+        return null;
+    }
 }

+ 76 - 82
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationSubtaskServiceImpl.java

@@ -4,6 +4,8 @@ import java.lang.reflect.Type;
 import java.util.*;
 
 import cn.hutool.core.util.ObjectUtil;
+import com.alibaba.fastjson2.JSON;
+import com.alibaba.fastjson2.JSONArray;
 import com.google.gson.Gson;
 import com.google.gson.reflect.TypeToken;
 import com.mybatisflex.core.paginate.Page;
@@ -14,9 +16,11 @@ import com.taais.biz.domain.bo.TargetIdentificationSubtaskDetailsBo;
 import com.taais.biz.domain.dto.AlgorithmConfigParamDto;
 import com.taais.biz.domain.dto.AlgorithmRequestDto;
 import com.taais.biz.domain.vo.AlgorithmTaskConfigurationVo;
+import com.taais.biz.domain.vo.TargetIdentificationSubtaskDetailsVo;
 import com.taais.common.core.constant.Constants;
 import com.taais.common.core.utils.MapstructUtils;
 import com.taais.common.core.utils.StringUtils;
+import com.taais.common.log.annotation.Log;
 import com.taais.common.orm.core.page.PageQuery;
 import com.taais.common.core.core.page.PageResult;
 import com.taais.common.orm.core.service.impl.BaseServiceImpl;
@@ -165,99 +169,89 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
         doProcess(targetIdentificationSubtask);
     }
 
+    public void executeOneTask(Long taskId) {
+        TargetIdentificationSubtaskDetailsBo details = detailsService.getById(taskId);
+        details.setStatus(BizConstant.TASK_STATUS_PENDING);
+        detailsService.update(details);
+        log.info("details: {}", details);
+        doProcessSubTaskDetail(details);
+    }
+
     private void doProcess(TargetIdentificationSubtask subtask) {
+        List<TargetIdentificationSubtaskDetailsBo> detailsList = detailsService.getTargetIdentificationSubtaskDetailsListBySubtaskId(subtask.getId());
+
+        for (TargetIdentificationSubtaskDetailsBo detail : detailsList) {
+            doProcessSubTaskDetail(detail);
+        }
+    }
+
+
+    /**
+     * 执行子任务详情
+     * @param detail
+     */
+    // todo 20240903-对接接口
+    private void doProcessSubTaskDetail(TargetIdentificationSubtaskDetailsBo detail) {
+        final String MINI_PREFIX = "/task";
         Long algorithmId = null;
         Long modelId = null;
         String parameters = null;
         String preprocessPath = null;
         String resultPath = null;
         String url = null;
-        StringJoiner errorMsg = null;
 
-        List<TargetIdentificationSubtaskDetailsBo> detailsList = detailsService.getTargetIdentificationSubtaskDetailsListBySubtaskId(subtask.getId());
+        url = detail.getType();
+        parameters = detail.getParameters();
+        preprocessPath = detail.getPreprocessPath();
+        resultPath = detail.getResultPath();
+        AlgorithmRequestDto algorithmRequestDto = new AlgorithmRequestDto();
+        algorithmRequestDto.setBizType(BizConstant.TYPE_DATA_BIZ_PROCESS);
+        algorithmRequestDto.setBizId(detail.getId());
+        algorithmRequestDto.setSourcePath(BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + preprocessPath);
+        algorithmRequestDto.setResultPath(BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + resultPath);
+        algorithmRequestDto.setLogPath(BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + resultPath + "/log/log.log");
+        algorithmRequestDto.setOtherParams(new HashMap<>());
 
-        for (TargetIdentificationSubtaskDetailsBo detail : detailsList) {
-            url = detail.getType();
-            errorMsg = new StringJoiner(System.lineSeparator());
-            algorithmId = detail.getAlgorithmId();
-            modelId = detail.getAlgorithmId();
-            parameters = detail.getParameters();
-            preprocessPath = detail.getPreprocessPath();
-            resultPath = detail.getResultPath();
-            //AlgorithmConfig config = algorithmConfigService.getById(algorithmId);
-            //if (config == null) {
-            //    log.error("算法配置未找到!!!algorithmId:{}", algorithmId);
-            //    errorMsg.add("算法配置未找到!!!");
-            //} else {
-            //    url = config.getAlgorithmAddress();
-            //}
-            //AlgorithmModel model = null;
-            //if (modelId != null) {
-            //    model = modelService.getById(modelId);
-            //}
-            // send http
-            AlgorithmRequestDto algorithmRequestDto = new AlgorithmRequestDto();
-            algorithmRequestDto.setBizType(BizConstant.TYPE_DATA_BIZ_PROCESS);
-            algorithmRequestDto.setBizId(subtask.getId());
-            algorithmRequestDto.setSourcePath(BizConstant.DOCKER_BASE_PATH + preprocessPath);
-            algorithmRequestDto.setResultPath(BizConstant.DOCKER_BASE_PATH + resultPath);
-            algorithmRequestDto.setLogPath(BizConstant.DOCKER_BASE_PATH + resultPath);
-            Gson gson = new Gson();
-            Type listType = new TypeToken<List<AlgorithmConfigParamDto>>() {}.getType();
-            if (StringUtils.isNotEmpty(parameters)) {
-                List<AlgorithmConfigParamDto> paramDtoList = gson.fromJson(parameters, listType);
-                Map<String, Object> otherParams = new HashMap<>(paramDtoList.size());
-                //if (model == null) {
-                //    log.error("模型配置未找到!!!modelId:{}", modelId);
-                //    errorMsg.add("模型配置未找到!!!");
-                //    // 找到训练的模型地址
-                //    String trainModelPath = mapper.getTrainModelPath(subtask.getId());
-                //    if (StringUtils.isNotEmpty(trainModelPath)){
-                //        otherParams.put("weight_path", BizConstant.DOCKER_BASE_PATH + trainModelPath + BizConstant.DOCKER_PT_PATH);
-                //    }
-                //} else {
-                //    SysOssVo modelOss = ossService.getById(Long.valueOf(model.getModelAddress()));
-                //    otherParams.put("pretrained_model", BizConstant.DOCKER_BASE_PATH + StringUtils.substringAfter(modelOss.getFileName(), Constants.RESOURCE_PREFIX));
-                //}
-                for (AlgorithmConfigParamDto algorithmConfigParamDto : paramDtoList) {
-                    String value = StringUtils.isNotEmpty(algorithmConfigParamDto.getValue()) ? algorithmConfigParamDto.getValue() : algorithmConfigParamDto.getDefaultValue();
-                    if(NumberUtils.isCreatable(value)){
-                        otherParams.put(algorithmConfigParamDto.getAgName(), NumberUtils.createNumber(value));
-                    } else {
-                        otherParams.put(algorithmConfigParamDto.getAgName(), value);
-                    }
-                }
-                algorithmRequestDto.setOtherParams(otherParams);
-            }
-            String httpResult = null;
-            Mono<String> response = null;
-            if (StringUtils.isEmpty(url)) {
-                errorMsg.add("url是空!!!");
-            } else {
-                log.info("http post url:{},body:{}", url,algorithmRequestDto);
-                WebClient webClient = WebClient.builder().build();
-                response = webClient.post()
-                    .uri(url)
-                    .bodyValue(algorithmRequestDto)
-                    .retrieve()
-                    .bodyToMono(String.class);
-            }
-            // process httpResult
-            log.info("httpResult:{}", httpResult);
-            // update AlgorithmBizProcess
-            TargetIdentificationSubtaskDetails algorithmBizProcess = MapstructUtils.convert(detail, TargetIdentificationSubtaskDetails.class);
-            if (errorMsg.length() != 0) {
-                algorithmBizProcess.setStatus(BizConstant.TASK_STATUS_FAILED);
-                algorithmBizProcess.setRemarks(errorMsg.toString());
-            } else {
-                algorithmBizProcess.setStatus(BizConstant.TASK_STATUS_PROCESSING);
-                algorithmBizProcess.setStartTime(new Date());
-            }
-            detailsService.updateById(algorithmBizProcess);
-            if (response != null) {
+        String taskName = detail.getName();
+        if (taskName.contains("训练")) {
+            log.info("train");
+        } else if (taskName.contains("验证")) {
+            String[] urls = url.split(";;;");
+            url = urls[0];
+            algorithmRequestDto.getOtherParams().put("weight_path", BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + urls[1] + "/result/weights/best.pt");
+        } else if (taskName.contains("测试")) {
+            String[] urls = url.split(";;;");
+            url = urls[0];
+            algorithmRequestDto.getOtherParams().put("weight_path", BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + urls[1] + "/result/weights/best.pt");
+        } else {
+            log.error("taskName error: " + taskName);
+            return;
+        }
+        // send http
+        System.out.println("todo: " + algorithmRequestDto.toString());
+
+        String httpResult = null;
+        Mono<String> response = null;
+        log.info("http post url:{},body:{}", url, algorithmRequestDto);
+        WebClient webClient = WebClient.builder().build();
+        response = webClient.post()
+            .uri(url)
+            .bodyValue(algorithmRequestDto)
+            .retrieve()
+            .bodyToMono(String.class);
+
+        if (response != null) {
+            log.info("version is : {}", detail.getVersion());
+            try {
                 httpResult = response.block();
-                log.info("targetIdentificationSubtaskDetail id: {} http response {}", subtask.getId(), httpResult);
+                detail.setStatus(BizConstant.TASK_STATUS_PROCESSING);
+            } catch (Exception e) {
+                log.error("http request error: {}", e.getMessage());
+                detail.setStatus(BizConstant.TASK_STATUS_FAILED);
+            } finally {
+                detail.setRemarks("REMARKS");
             }
+            detailsService.update(detail);
         }
     }
 }

+ 135 - 19
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationTaskServiceImpl.java

@@ -2,9 +2,7 @@ package com.taais.biz.service.impl;
 
 import java.io.File;
 import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Random;
+import java.util.*;
 
 import cn.hutool.core.util.ObjectUtil;
 import com.esotericsoftware.minlog.Log;
@@ -35,6 +33,8 @@ import com.taais.biz.domain.TargetIdentificationTask;
 import com.taais.biz.domain.bo.TargetIdentificationTaskBo;
 import com.taais.biz.domain.vo.TargetIdentificationTaskVo;
 
+import javax.security.auth.Subject;
+
 import static com.taais.biz.domain.table.TargetIdentificationTaskTableDef.TARGET_IDENTIFICATION_TASK;
 
 /**
@@ -49,9 +49,9 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
     private TargetIdentificationTaskMapper targetIdentificationTaskMapper;
 
     @Resource
-    private ITargetIdentificationSubtaskService subtaskService;
+    private TargetIdentificationSubtaskServiceImpl subtaskService;
     @Resource
-    private ITargetIdentificationSubtaskDetailsService subtaskDetailsService;
+    private TargetIdentificationSubtaskDetailsServiceImpl subtaskDetailsService;
 
     @Resource
     private IAlgorithmTaskConfigurationService algorithmTaskConfigurationService;
@@ -170,11 +170,15 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
 
     @Override
     public void taskRun() {
-        TargetIdentificationTask task = targetIdentificationTaskMapper.getAvailableTask();
-        if (task == null) {
-            return;
+        Long id = subtaskDetailsService.getAvailableTask();
+        if (id != null) {
+            subtaskService.executeOneTask(id);
         }
-        subtaskService.taskRun(task);
+        //TargetIdentificationTask task = targetIdentificationTaskMapper.getAvailableTask();
+        //if (task == null) {
+        //    return;
+        //}
+        //subtaskService.taskRun(task);
     }
 
     @Override
@@ -195,22 +199,23 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
         taskBo.setId(task.getId());
         System.out.println("taskBo is: " + taskBo);
 
+        Map<String, String> records = null;
         // 创建训练子任务
         if(taskDto.getTaskItemList().contains(CreateTargetIdentificationTaskDto.TASK_TYPE_SINGLE_DATA_AND_MORE_ALGORITHM)){
-            createTrainTask(taskBo.getId(),taskDto);
+            records = createTrainTask(taskBo.getId(),taskDto);
         }
         // 创建测试子任务
         if (taskDto.getTaskItemList().contains(CreateTargetIdentificationTaskDto.TASK_TYPE_SINGLE_DATA_AND_MORE_ALGORITHM)){
-            createTestTask(taskBo.getId(),taskDto);
+            createTestTask(taskBo.getId(),taskDto, records);
         }
         return null;
     }
 
-    private void createTrainTask(Long taskId, CreateTargetIdentificationTaskDto taskDto) {
+    private Map<String, String> createTrainTask(Long taskId, CreateTargetIdentificationTaskDto taskDto) {
         List<TaskDto> algTaskList = taskDto.getAlgTaskList();
         List<String> trainBatchNumList = taskDto.getTrainBatchNumList();
         if (trainBatchNumList.isEmpty()) {
-            return;
+            return null;
         }
         TargetIdentificationSubtaskBo subtask = new TargetIdentificationSubtaskBo();
         subtask.setName("训练");
@@ -222,13 +227,16 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
         subtask.setId(savedTask.getId());
         System.out.println("subtask is: " + subtask);
 
+        Map<String, String> records = new HashMap<>();
+
         for (TaskDto algTask : algTaskList) {
             // 通过算法id 获取算法配置
             Long algorithmId = algTask.getAlgorithmId();
             AlgorithmTaskConfigurationVo algorithmModelVo = algorithmTaskConfigurationService.selectById(algorithmId);
 
-            String algUrl = algorithmModelVo.getTestUrl();
+            String algUrl = algorithmModelVo.getTrainUrl();
             String algName = algorithmModelVo.getName();
+
             List<String> params = List.of(algTask.getParams().split(";;;"));
 
             for (int i = 0; i < trainBatchNumList.size(); i++) {
@@ -236,14 +244,14 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
                 TargetIdentificationSubtaskDetailsBo subtaskDetail = new TargetIdentificationSubtaskDetailsBo();
                 // 通过算法id获取算法配置
                 subtaskDetail.setSubtaskId(subtask.getId());
-                subtaskDetail.setName(algName);
+                subtaskDetail.setName(algName + "_训练");
                 subtaskDetail.setStatus(BizConstant.TASK_STATUS_PENDING);
                 subtaskDetail.setAlgorithmId(algTask.getAlgorithmId());
                 subtaskDetail.setType(algUrl);
                 subtaskDetail.setDataBatchNums(batchNum);
                 subtaskDetail.setParameters(params.get(0));
-                // todo 获取预处理路径
                 String subtaskPath = "/" + UUID.randomUUID().toString().replace("-", "_");
+                records.put(algName, subtaskPath);
                 subtaskDetail.setPreprocessPath(subtaskPath);
                 subtaskDetail.setResultPath(subtaskPath + "/result");
                 subtaskDetail.setIndex((long) i);
@@ -252,6 +260,7 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
                 subtaskDetailsService.insert(subtaskDetail);
             }
         }
+        return records;
     }
 
     public static final String PATH_PREFIX = "ObjectDetection_Web/task";
@@ -318,7 +327,102 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
         }
     }
 
-    private void createTestTask(Long taskId, CreateTargetIdentificationTaskDto taskDto) {
+    /**
+     * 移动文件到对应文件夹
+     * @param batch
+     * @param path
+     */
+    private void copyFilesToPath(String batch, String path, boolean moveLabelFile) {
+        String[] batches = batch.split(",");
+
+        File dir = new File(PATH_PREFIX + path);
+        if (!dir.exists()) {
+            dir.mkdirs();
+        }
+        dir = new File(PATH_PREFIX + path + "/images");
+        if (!dir.exists()) {
+            dir.mkdirs();
+        }
+        dir = new File(PATH_PREFIX + path + "/labels");
+        if (!dir.exists()) {
+            dir.mkdirs();
+        }
+        dir = new File(PATH_PREFIX + path + "/result");
+        if (!dir.exists()) {
+            dir.mkdirs();
+        }
+
+        for (String batchNum : batches) {
+            List<DataVo> dataVoList = dataService.getDataByBatchNum(batchNum);
+            for (DataVo dataVo : dataVoList) {
+                String[] strings = null;
+                String relativePath = null;
+                File file = null;
+
+                if (dataVo.getLabelurl() != null) {
+                    strings = dataVo.getLabelurl().split("/profile");
+                    relativePath = "ObjectDetection_Web" + strings[strings.length - 1];
+                    file = new File(relativePath);
+                    if (file.exists()) {
+                        if (!moveLabelFile) {
+                            continue;
+                        }
+                        try {
+                            File dist = new File(PATH_PREFIX + path + "/labels/" + file.getName());
+                            FileUtils.copyFile(file, dist);
+                            System.out.println("file dist: " + dist.getAbsolutePath());
+                        } catch (IOException e) {
+                            Log.debug("bug found");
+                            continue;
+                        }
+                    }
+                } else if (moveLabelFile) {
+                    continue;
+                }
+
+                strings = dataVo.getUrl().split("/profile");
+                relativePath = "ObjectDetection_Web" + strings[strings.length - 1];
+                file = new File(relativePath);
+                if (file.exists()) {
+                    try {
+                        File dist = new File(PATH_PREFIX + path + "/images/" + file.getName());
+                        FileUtils.copyFile(file, dist);
+                        System.out.println("file dist: " + dist.getAbsolutePath());
+                    } catch (IOException e) {
+                        Log.debug("bug found");
+                        continue;
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * 检查是否有验证集
+     * @param batch
+     */
+    private boolean hasValidationSet(String batch) {
+        String[] batches = batch.split(",");
+
+        for (String batchNum : batches) {
+            List<DataVo> dataVoList = dataService.getDataByBatchNum(batchNum);
+            for (DataVo dataVo : dataVoList) {
+                String[] strings = null;
+
+                if (StringUtils.isNotEmpty(dataVo.getLabelurl())) {
+                    strings = dataVo.getLabelurl().split("/profile");
+                    String relativePath = "ObjectDetection_Web" + strings[strings.length - 1];
+                    File file = new File(relativePath);
+                    if (file.exists()) {
+                        return true;
+                    }
+                }
+            }
+        }
+        return false;
+    }
+
+    private void createTestTask(Long taskId, CreateTargetIdentificationTaskDto taskDto, Map<String, String> records) {
         List<TaskDto> algTaskList = taskDto.getAlgTaskList();
         List<String> testBatchNumList = taskDto.getTestBatchNumList();
         if (testBatchNumList.isEmpty()) {
@@ -354,13 +458,25 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
                 subtaskDetail.setType(algUrl);
                 subtaskDetail.setDataBatchNums(batchNum);
                 subtaskDetail.setParameters(params.get(2));
-                // todo 获取预处理路径
                 String subtaskPath = "/" + UUID.randomUUID().toString().replace("-", "_");
                 subtaskDetail.setPreprocessPath(subtaskPath);
                 subtaskDetail.setResultPath(subtaskPath + "/result");
                 subtaskDetail.setIndex((long) i);
-                copyFilesToPath(batchNum, subtaskPath);
 
+                if (hasValidationSet(batchNum)) {
+                    subtaskDetail.setName(algName + "_验证");
+                    copyFilesToPath(batchNum, subtaskPath, true);
+                    subtaskDetail.setType(algorithmModelVo.getVerifyUrl() + ";;;" + records.get(algName));
+                    subtaskDetailsService.insert(subtaskDetail);
+                    subtaskPath = "/" + UUID.randomUUID().toString().replace("-", "_");
+                    subtaskDetail.setPreprocessPath(subtaskPath);
+                    subtaskDetail.setResultPath(subtaskPath + "/result");
+                }
+
+                subtaskDetail.setName(algName + "_测试");
+                subtaskDetail.setPreprocessPath(subtaskDetail.getPreprocessPath() + "/images");
+                subtaskDetail.setType(algorithmModelVo.getTestUrl() + ";;;" + records.get(algName));
+                copyFilesToPath(batchNum, subtaskPath, false);
                 subtaskDetailsService.insert(subtaskDetail);
             }
         }

+ 5 - 0
taais-modules/taais-biz/src/main/resources/mapper/identification/TargetIdentificationSubtaskDetailsMapper.xml

@@ -8,4 +8,9 @@
             resultType="com.taais.biz.domain.bo.TargetIdentificationSubtaskDetailsBo">
         select * from target_identification_subtask_details where subtask_id = #{subtaskId}
     </select>
+    <select id="getByTaskType" resultType="java.lang.Long">
+        select id from target_identification_subtask_details
+                  where name like '%' || #{taskType} || '%' and
+                                                        status = '0'
+    </select>
 </mapper>