Procházet zdrojové kódy

[FEATURE] 手动启动任务

Suuuuuukang před 10 měsíci
rodič
revize
46b2c6a11f

+ 10 - 10
taais-modules/taais-biz/src/main/java/com/taais/biz/component/ScheduledTasks.java

@@ -26,14 +26,14 @@ public class ScheduledTasks {
     //    log.info("ScheduledTasks.runTask end");
     //}
 
-    @Scheduled(fixedRate = 30000)
-    public void taskRun() {
-        try {
-            log.info("ScheduledTasks.taskRun start");
-            targetIdentificationTaskService.taskRun();
-            log.info("ScheduledTasks.taskRun end");
-        } catch (Exception e) {
-            log.error("ScheduledTasks.taskRun error", e);
-        }
-    }
+    //@Scheduled(fixedRate = 30000000)
+    //public void taskRun() {
+    //    try {
+    //        log.info("ScheduledTasks.taskRun start");
+    //        targetIdentificationTaskService.taskRun();
+    //        log.info("ScheduledTasks.taskRun end");
+    //    } catch (Exception e) {
+    //        log.error("ScheduledTasks.taskRun error", e);
+    //    }
+    //}
 }

+ 3 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/PublicController.java

@@ -27,6 +27,7 @@ import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
 
 import java.util.Arrays;
+import java.util.Date;
 
 /**
  * @author allen
@@ -72,6 +73,8 @@ public class PublicController extends BaseController {
             detailsBo.setStatus(resultDTO.getStatus() != 200 ? BizConstant.TASK_STATUS_FAILED :
                 resultDTO.getMsg().contains("finish") ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_PROCESSING);
             detailsBo.setRemarks(JSON.toJSONString(Arrays.asList(resultDTO.getMin(), resultDTO.getMax(), resultDTO.getAverage())));
+            detailsBo.setEndTime(new Date());
+            detailsBo.setCostSecond(detailsBo.getEndTime().getTime() - detailsBo.getStartTime().getTime());
             detailsService.update(detailsBo);
 
             // 保存模型

+ 60 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/TargetIdentificationSubtaskDetailsController.java

@@ -2,11 +2,14 @@ package com.taais.biz.controller;
 
 import java.io.File;
 import java.util.ArrayList;
+import java.util.Date;
 import java.util.List;
 import java.util.Map;
 
 import com.taais.biz.constant.BizConstant;
+import com.taais.biz.domain.TargetIdentificationSubtask;
 import com.taais.biz.domain.TargetIdentificationSubtaskDetails;
+import com.taais.biz.domain.vo.TargetIdentificationSubtaskVo;
 import com.taais.biz.service.impl.TargetIdentificationSubtaskServiceImpl;
 import com.taais.biz.service.impl.TargetIdentificationTaskServiceImpl;
 import com.taais.biz.utils.ZipDirectory;
@@ -14,6 +17,7 @@ import com.taais.common.core.utils.StringUtils;
 import lombok.RequiredArgsConstructor;
 import jakarta.servlet.http.HttpServletResponse;
 import cn.dev33.satoken.annotation.SaCheckPermission;
+import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.validation.annotation.Validated;
@@ -36,6 +40,7 @@ import com.taais.common.core.core.page.PageResult;
  * @author 0
  * 2024-08-17
  */
+@Slf4j
 @Validated
 @RequiredArgsConstructor
 @RestController
@@ -120,12 +125,66 @@ public class TargetIdentificationSubtaskDetailsController extends BaseController
         return CommonResult.success();
     }
 
+    /**
+     * 执行训练任务
+     * @param params
+     * @return
+     */
+    @PostMapping("/startTask")
+    public CommonResult<Void> startTask(@RequestBody Map<String, String> params) {
+        try {
+            Long taskId = Long.parseLong(params.get("taskId"));
+            TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(taskId);
+            if (!(BizConstant.TASK_STATUS_PENDING.equals(details.getStatus()) ||
+                BizConstant.TASK_STATUS_FAILED.equals(details.getStatus()))) {
+                return CommonResult.fail("任务正在执行中,请勿重复执行!");
+            }
+
+            details.setStartTime(new Date());
+            targetIdentificationSubtaskDetailsService.updateById(details);
+            subtaskService.executeOneTask(taskId);
+            return CommonResult.success();
+        } catch (Exception e) {
+            return CommonResult.fail(e.getMessage());
+        }
+    }
+
+    /**
+     * 执行测试或验证任务
+     * @param params
+     * @return
+     */
     @PostMapping("/execute")
-    public CommonResult<Void> add(@RequestBody Map<String, String> params) {
+    public CommonResult<Void> executeTask(@RequestBody Map<String, String> params) {
         try {
             Long taskId = Long.parseLong(params.get("taskId"));
             TargetIdentificationSubtaskDetails details = targetIdentificationSubtaskDetailsService.getById(taskId);
+
+            // check if training task is done
+            TargetIdentificationSubtask subtask = subtaskService.getById(details.getSubtaskId());
+            List<TargetIdentificationSubtaskVo> subtaskList = subtaskService.getSubtaskList(subtask.getTaskId());
+            List<Long> trainingTaskIds = new ArrayList<>();
+            subtaskList.forEach(subtaskVo -> {
+                if (subtaskVo.getName().contains("训练")) {
+                    trainingTaskIds.add(subtaskVo.getId());
+                }
+            });
+            boolean isAllTrainingTaskDone = trainingTaskIds.stream().allMatch(trainingTaskId -> {
+                List<TargetIdentificationSubtaskDetailsBo> list = targetIdentificationSubtaskDetailsService.getBySubtaskId(trainingTaskId);
+                return list.stream().allMatch(taskDetails -> BizConstant.TASK_STATUS_SUCCEED.equals(taskDetails.getStatus()));
+            });
+
+            if (!isAllTrainingTaskDone) {
+                return CommonResult.fail("训练任务未完成,请先完成训练任务!");
+            }
+
+            if (!(BizConstant.TASK_STATUS_PENDING.equals(details.getStatus()) ||
+                BizConstant.TASK_STATUS_FAILED.equals(details.getStatus()))) {
+                return CommonResult.fail("任务正在执行中,请勿重复执行!");
+            }
+
             details.setStatus(BizConstant.TASK_STATUS_PENDING);
+            details.setStartTime(new Date());
             targetIdentificationSubtaskDetailsService.updateById(details);
 
             subtaskService.executeOneTask(taskId);

+ 5 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/mapper/TargetIdentificationSubtaskMapper.java

@@ -1,9 +1,12 @@
 package com.taais.biz.mapper;
 
 import com.mybatisflex.core.BaseMapper;
+import com.taais.biz.domain.vo.TargetIdentificationSubtaskVo;
 import org.apache.ibatis.annotations.Mapper;
 import com.taais.biz.domain.TargetIdentificationSubtask;
 
+import java.util.List;
+
 /**
  * 目标识别子任务Mapper接口
  *
@@ -13,4 +16,6 @@ import com.taais.biz.domain.TargetIdentificationSubtask;
 @Mapper
 public interface TargetIdentificationSubtaskMapper extends BaseMapper<TargetIdentificationSubtask> {
     TargetIdentificationSubtask getAvailableSubtask(long taskId);
+
+    List<TargetIdentificationSubtaskVo> selectList(Long taskId);
 }

+ 1 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/service/ITargetIdentificationSubtaskDetailsService.java

@@ -63,4 +63,5 @@ public interface ITargetIdentificationSubtaskDetailsService extends IBaseService
      */
     boolean deleteByIds(Long[] ids);
 
+    List<TargetIdentificationSubtaskDetailsBo> getBySubtaskId(Long trainingTaskId);
 }

+ 5 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationSubtaskDetailsServiceImpl.java

@@ -163,6 +163,11 @@ public class TargetIdentificationSubtaskDetailsServiceImpl extends BaseServiceIm
         return this.removeByIds(Arrays.asList(ids));
     }
 
+    @Override
+    public List<TargetIdentificationSubtaskDetailsBo> getBySubtaskId(Long trainingTaskId) {
+        return targetIdentificationSubtaskDetailsMapper.getTargetIdentificationSubtaskDetailsListBySubtaskId(trainingTaskId);
+    }
+
     public Long getAvailableTask() {
         List<Long> list = null;
         list = targetIdentificationSubtaskDetailsMapper.getByTaskType("训练");

+ 18 - 10
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationSubtaskServiceImpl.java

@@ -258,24 +258,32 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
         // send http
         System.out.println("todo request: " + algorithmRequestDto.toString());
 
-        String res = HttpUtil.post(url, JSONUtil.toJsonStr(algorithmRequestDto));
-        log.info("res is: {}", res);
-        if (res != null) {
-            log.info("version is : {}", detail.getVersion());
-            try {
+        try {
+            String res = HttpUtil.post(url, JSONUtil.toJsonStr(algorithmRequestDto));
+            log.info("res is: {}", res);
+            if (res != null) {
+                //log.info("version is : {}", detail.getVersion());
                 JSONObject jsonObject = JSON.parseObject(res);
                 if (jsonObject.getInteger("status") == 200) {
                     detail.setStatus(BizConstant.TASK_STATUS_SUCCEED);
                 } else {
                     detail.setStatus(BizConstant.TASK_STATUS_FAILED);
+                    detail.setEndTime(new Date());
+                    detail.setCostSecond(detail.getEndTime().getTime() - detail.getStartTime().getTime());
                 }
-            } catch (Exception e) {
-                log.error("http request error: {}", e.getMessage());
-                detail.setStatus(BizConstant.TASK_STATUS_FAILED);
-            } finally {
-                detail.setRemarks("REMARKS");
             }
+        } catch (Exception e) {
+            log.error("http request error: {}", e.getMessage());
+            detail.setStatus(BizConstant.TASK_STATUS_FAILED);
+            detail.setEndTime(new Date());
+            detail.setCostSecond(detail.getEndTime().getTime() - detail.getStartTime().getTime());
+        } finally {
+            detail.setRemarks("REMARKS");
             detailsService.update(detail);
         }
     }
+
+    public List<TargetIdentificationSubtaskVo> getSubtaskList(Long taskId) {
+        return targetIdentificationSubtaskMapper.selectList(taskId);
+    }
 }

+ 4 - 0
taais-modules/taais-biz/src/main/resources/mapper/identification/TargetIdentificationSubtaskMapper.xml

@@ -15,4 +15,8 @@
             t.create_time DESC
         LIMIT 1
     </select>
+    <select id="selectList" resultType="com.taais.biz.domain.vo.TargetIdentificationSubtaskVo">
+        select * from target_identification_subtask
+        where task_id = #{taskId}
+    </select>
 </mapper>