allen 11 mesi fa
parent
commit
48decede03

+ 7 - 5
script/sql/postgresql/update_task_and_subtask.sql

@@ -1,11 +1,12 @@
 CREATE OR REPLACE PROCEDURE update_task_and_subtask(
-    IN p_biz_id INT,
+    IN p_biz_id bigint,
     IN p_bizType VARCHAR
 )
 LANGUAGE plpgsql
 AS $$
 BEGIN
-    IF p_bizType = 'dataProcess' THEN
+    IF p_bizType = 'dataBizProcess' THEN
+    RAISE NOTICE 'Starting dataBizProcess';
         -- 检查 subtask 状态 --失败场景暂时不更新上层任务
         -- PERFORM 1 FROM algorithm_biz_process WHERE subtaskId = p_biz_id AND status = '3';
         -- IF FOUND THEN
@@ -30,13 +31,14 @@ BEGIN
             END IF;
 
         --END IF;
-    ELSIF p_bizType = 'dataBizProcess' THEN
+    ELSIF p_bizType = 'dataProcess' THEN
+    RAISE NOTICE 'Starting dataProcess';
         -- 如果algorithm_data_process所有状态都是2(成功),更新 algorithm_subtask 表为2(成功)
-        PERFORM 1 FROM algorithm_data_process WHERE status <> '2' AND sub_task_id = ( SELECT sub_task_id FROM algorithm_biz_process WHERE id = p_biz_id);
+        PERFORM 1 FROM algorithm_data_process WHERE status <> '2' AND sub_task_id = ( SELECT sub_task_id FROM algorithm_data_process WHERE id = p_biz_id);
         IF NOT FOUND THEN
             UPDATE algorithm_subtask
             SET status = '2'
-            WHERE id = ( SELECT sub_task_id FROM algorithm_biz_process WHERE id = p_biz_id);
+            WHERE id = ( SELECT sub_task_id FROM algorithm_data_process WHERE id = p_biz_id);
         END IF;
         -- 如果algorithm_subtask表所有状态都是2(成功),更新 algorithm_task 表为2(成功)
         PERFORM 1 FROM algorithm_subtask WHERE status <> '2' AND task_id = (SELECT task_id FROM algorithm_subtask WHERE id = (SELECT sub_task_id FROM algorithm_data_process WHERE id = p_biz_id));

+ 2 - 0
taais-admin/src/main/java/com/taais/TaaisApplication.java

@@ -3,12 +3,14 @@ package com.taais;
 import org.springframework.boot.SpringApplication;
 import org.springframework.boot.autoconfigure.SpringBootApplication;
 import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
+import org.springframework.scheduling.annotation.EnableScheduling;
 
 /**
  * Km启动程序
  *
  * @author wgk
  */
+@EnableScheduling
 @SpringBootApplication(exclude = {DataSourceAutoConfiguration.class}, scanBasePackages = {"com.taais"})
 public class TaaisApplication {
     public static void main(String[] args) {

+ 9 - 0
taais-modules/taais-biz/pom.xml

@@ -49,6 +49,15 @@
             <groupId>com.google.code.gson</groupId>
             <artifactId>gson</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.springframework</groupId>
+            <artifactId>spring-webflux</artifactId>
+        </dependency>
+        <!--        <dependency>-->
+<!--            <groupId>org.apache.httpcomponents</groupId>-->
+<!--            <artifactId>httpclient</artifactId>-->
+<!--            <version>4.5.14</version>-->
+<!--        </dependency>-->
 
     </dependencies>
 

+ 22 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/component/ScheduledTasks.java

@@ -0,0 +1,22 @@
+package com.taais.biz.component;
+
+import com.taais.biz.service.IAlgorithmTaskService;
+import jakarta.annotation.Resource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.scheduling.annotation.Scheduled;
+import org.springframework.stereotype.Component;
+
+@Component
+public class ScheduledTasks {
+    private static final Logger log = LoggerFactory.getLogger(ScheduledTasks.class);
+
+    @Resource
+    IAlgorithmTaskService algorithmTaskService;
+    @Scheduled(fixedRate = 10000)
+    public void runTask() {
+        log.info("ScheduledTasks.runTask start");
+//        algorithmTaskService.taskRun();
+        log.info("ScheduledTasks.runTask end");
+    }
+}

+ 13 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/component/WebClientConfig.java

@@ -0,0 +1,13 @@
+package com.taais.biz.component;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.reactive.function.client.WebClient;
+
+@Configuration
+public class WebClientConfig {
+    @Bean
+    public WebClient.Builder webClientBuilder() {
+        return WebClient.builder();
+    }
+}

+ 1 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/constant/BizConstant.java

@@ -63,4 +63,5 @@ public class BizConstant {
     public static final String F1_CURVE = "F1_curve.png";
     public static final String ORIGINAL_IMAGE= "原始图片";
     public static final String DOCKER_BASE_PATH= "/workspace";
+    public static final String DOCKER_PT_PATH= "weights/best.pt";
 }

+ 2 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/mapper/AlgorithmSubtaskMapper.java

@@ -18,4 +18,6 @@ public interface AlgorithmSubtaskMapper extends BaseMapper<AlgorithmSubtask> {
     AlgorithmSubtask getFirstNeedProcessSubtask(Long taskId);
 
     List<AlgorithmSubtaskVo> getSubtaskByTaskId(Long taskId);
+
+    String getTrainModelPath(Long id, Long algorithmId);
 }

+ 33 - 7
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/AlgorithmSubtaskServiceImpl.java

@@ -15,18 +15,22 @@ import com.taais.biz.domain.vo.AlgorithmDataProcessVo;
 import com.taais.biz.domain.vo.AlgorithmSubtaskVo;
 import com.taais.biz.mapper.AlgorithmSubtaskMapper;
 import com.taais.biz.service.*;
+import com.taais.common.core.constant.Constants;
 import com.taais.common.core.core.page.PageResult;
 import com.taais.common.core.utils.MapstructUtils;
 import com.taais.common.core.utils.StringUtils;
-import com.taais.common.core.utils.http.HttpUtils;
 import com.taais.common.orm.core.page.PageQuery;
 import com.taais.common.orm.core.service.impl.BaseServiceImpl;
+import com.taais.system.domain.vo.SysOssVo;
+import com.taais.system.service.ISysOssService;
 import jakarta.annotation.Resource;
 import org.apache.commons.lang3.math.NumberUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.publisher.Mono;
 
 import java.lang.reflect.Type;
 import java.util.*;
@@ -44,6 +48,8 @@ public class AlgorithmSubtaskServiceImpl extends BaseServiceImpl<AlgorithmSubtas
 
     private static final Logger log = LoggerFactory.getLogger(AlgorithmSubtaskServiceImpl.class);
     @Resource
+    private WebClient.Builder webClientBuilder;
+    @Resource
     private AlgorithmSubtaskMapper algorithmSubtaskMapper;
     @Resource
     private IAlgorithmDataSetService dataSetService;
@@ -55,6 +61,8 @@ public class AlgorithmSubtaskServiceImpl extends BaseServiceImpl<AlgorithmSubtas
     private IAlgorithmConfigService algorithmConfigService;
     @Resource
     private IAlgorithmModelService modelService;
+    @Resource
+    private ISysOssService ossService;
 
     @Override
     public QueryWrapper query() {
@@ -227,17 +235,22 @@ public class AlgorithmSubtaskServiceImpl extends BaseServiceImpl<AlgorithmSubtas
             if (StringUtils.isNotEmpty(parameters)) {
                 List<AlgorithmConfigParamDto> paramDtoList = gson.fromJson(parameters, listType);
                 Map<String, Object> otherParams = new HashMap<>(paramDtoList.size());
-                // todo allen 模型没加载进去
                 if (model == null) {
                     log.error("模型配置未找到!!!modelId:{}", modelId);
                     errorMsg.add("模型配置未找到!!!");
                 } else {
-                    otherParams.put("pretrained_model", model.getModelAddress());
+                    SysOssVo modelOss = ossService.getById(Long.valueOf(model.getModelAddress()));
+                    otherParams.put("pretrained_model", BizConstant.DOCKER_BASE_PATH + StringUtils.substringAfter(modelOss.getFileName(), Constants.RESOURCE_PREFIX));
+                    // 找到训练的模型地址
+                    String trainModelPath = mapper.getTrainModelPath(bizProcessVo.getId(), bizProcessVo.getAlgorithmId());
+                    if (StringUtils.isNotEmpty(trainModelPath)){
+                        otherParams.put("weight_path", BizConstant.DOCKER_BASE_PATH + trainModelPath + BizConstant.DOCKER_PT_PATH);
+                    }
                 }
                 for (AlgorithmConfigParamDto algorithmConfigParamDto : paramDtoList) {
                     String value = StringUtils.isNotEmpty(algorithmConfigParamDto.getValue()) ? algorithmConfigParamDto.getValue() : algorithmConfigParamDto.getDefaultValue();
-                    if(NumberUtils.isDigits(value)){
-                        otherParams.put(algorithmConfigParamDto.getAgName(), NumberUtils.createFloat(value));
+                    if(NumberUtils.isCreatable(value)){
+                        otherParams.put(algorithmConfigParamDto.getAgName(), NumberUtils.createNumber(value));
                     } else {
                         otherParams.put(algorithmConfigParamDto.getAgName(), value);
                     }
@@ -248,7 +261,13 @@ public class AlgorithmSubtaskServiceImpl extends BaseServiceImpl<AlgorithmSubtas
             if (StringUtils.isEmpty(url)) {
                 errorMsg.add("url是空!!!");
             } else {
-                httpResult = HttpUtils.sendPost(url, gson.toJson(algorithmRequestDto));
+                WebClient webClient = webClientBuilder.build();
+                Mono<String> response = webClient.post()
+                    .uri(url)
+                    .bodyValue(algorithmRequestDto)
+                    .retrieve()
+                    .bodyToMono(String.class);
+                httpResult = response.block();
             }
             // process httpResult
             log.info("httpResult:{}", httpResult);
@@ -311,7 +330,14 @@ public class AlgorithmSubtaskServiceImpl extends BaseServiceImpl<AlgorithmSubtas
             if (StringUtils.isEmpty(url)) {
                 errorMsg.add("url是空!!!");
             } else {
-                httpResult = HttpUtils.sendPost(url, gson.toJson(algorithmRequestDto));
+//                httpResult = HttpUtils.sendPost(url, gson.toJson(algorithmRequestDto));
+                WebClient webClient = webClientBuilder.build();
+                Mono<String> response = webClient.post()
+                    .uri(url)
+                    .bodyValue(algorithmRequestDto)
+                    .retrieve()
+                    .bodyToMono(String.class);
+                httpResult = response.block();
             }
             // process httpResult
             log.info("httpResult:{}", httpResult);

+ 3 - 3
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/AlgorithmTaskServiceImpl.java

@@ -277,7 +277,7 @@ public class AlgorithmTaskServiceImpl extends BaseServiceImpl<AlgorithmTaskMappe
         subtaskService.insert(algorithmSubtask);
         Long index = 0L;
         for (String testDataFolderPath : testDataFolderPathList) {
-            for (TaskDto dto : taskDto.getTrain()) {
+            for (TaskDto dto : taskDto.getTest()) {
                 // 创建算法处理任务
                 AlgorithmBizProcessBo algorithmBizProcessBo = new AlgorithmBizProcessBo();
                 AlgorithmConfig algorithmConfig = algorithmConfigService.getById(dto.getAlgorithmId());
@@ -396,7 +396,7 @@ public class AlgorithmTaskServiceImpl extends BaseServiceImpl<AlgorithmTaskMappe
         subtaskService.insert(algorithmSubtask);
         Long index = 0L;
         for (String reasoningDataFolderPath : reasoningDataFolderPathList) {
-            for (TaskDto dto : taskDto.getTrain()) {
+            for (TaskDto dto : taskDto.getReasoning()) {
                 // 创建算法处理任务
                 AlgorithmBizProcessBo algorithmBizProcessBo = new AlgorithmBizProcessBo();
                 AlgorithmConfig algorithmConfig = algorithmConfigService.getById(dto.getAlgorithmId());
@@ -404,7 +404,7 @@ public class AlgorithmTaskServiceImpl extends BaseServiceImpl<AlgorithmTaskMappe
                 algorithmBizProcessBo.setAlgorithmId(dto.getAlgorithmId());
                 algorithmBizProcessBo.setModelId(dto.getModelId());
                 algorithmBizProcessBo.setStatus(BizConstant.TASK_STATUS_PENDING);
-                algorithmBizProcessBo.setPreprocessPath(reasoningDataFolderPath);
+                algorithmBizProcessBo.setPreprocessPath(reasoningDataFolderPath + BizConstant.IMAGE);
                 algorithmBizProcessBo.setIndex(index);
                 algorithmBizProcessBo.setSubTaskId(algorithmSubtask.getId());
                 bizProcessService.insert(algorithmBizProcessBo);

+ 3 - 3
taais-modules/taais-biz/src/main/resources/mapper/task/AlgorithmBizProcessMapper.xml

@@ -25,11 +25,11 @@
         <result property="remarks" column="remarks" />
     </resultMap>
 
-    <select id="updateTaskAndSubtask" statementType="CALLABLE">
+    <update id="updateTaskAndSubtask" statementType="CALLABLE">
         call update_task_and_subtask(#{bizId,jdbcType=INTEGER,mode=IN}, #{bizType,jdbcType=VARCHAR,mode=IN})
-    </select>
+    </update>
 
     <select id="getProcessBySubtaskId" resultMap="AlgorithmBizProcessResultMap">
-        select * from algorithm_biz_process where sub_task_id = #{value} order by index asc
+        select * from algorithm_biz_process where sub_task_id = #{value} and status = '0' order by index asc
     </select>
 </mapper>

+ 1 - 1
taais-modules/taais-biz/src/main/resources/mapper/task/AlgorithmDataProcessMapper.xml

@@ -26,7 +26,7 @@
     </resultMap>
 
     <select id="getProcessBySubtaskId" resultMap="AlgorithmDataProcessResultMap">
-        select * from algorithm_data_process where sub_task_id = #{value}
+        select * from algorithm_data_process where sub_task_id = #{value} and status = '0' order by index asc
     </select>
 
 </mapper>

+ 24 - 0
taais-modules/taais-biz/src/main/resources/mapper/task/AlgorithmSubtaskMapper.xml

@@ -29,4 +29,28 @@
     select * from algorithm_subtask t where t.task_id = #{taskId} order by index asc
 </select>
 
+    <select id="getTrainModelPath" parameterType="Long" resultType="String">
+        SELECT
+            abp.result_path
+        FROM
+            algorithm_biz_process abp
+        WHERE
+            abp.algorithm_id = (SELECT parent_id FROM algorithm_config ac WHERE ac.id = (SELECT abp2.algorithm_id FROM algorithm_biz_process abp2 WHERE abp2.id = #{bizId}))
+        AND abp.sub_task_id = (
+            SELECT
+                ast.id
+            FROM
+                algorithm_subtask ast
+            WHERE
+                ast.type = '1' AND ast.task_id = (
+                SELECT
+                    ast2.task_id
+                FROM
+                    algorithm_subtask ast2
+                WHERE
+                    ast2.id = ( SELECT abp2.sub_task_id FROM algorithm_biz_process abp2 WHERE abp2.id = #{bizId} )
+            )
+        )
+    </select>
+
 </mapper>