浏览代码

[FIX] 若干修复项,核心是创建任务的数据修复

Suuuuuukang 9 月之前
父节点
当前提交
02b9584faf

+ 7 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/AlgorithmModelController.java

@@ -2,6 +2,7 @@ package com.taais.biz.controller;
 
 import java.util.List;
 
+import com.taais.biz.service.impl.AlgorithmModelServiceImpl;
 import lombok.RequiredArgsConstructor;
 import jakarta.servlet.http.HttpServletResponse;
 import cn.dev33.satoken.annotation.SaCheckPermission;
@@ -32,7 +33,7 @@ import com.taais.common.core.core.page.PageResult;
 @RequestMapping("/ag/model")
 public class AlgorithmModelController extends BaseController {
     @Resource
-    private IAlgorithmModelService algorithmModelService;
+    private AlgorithmModelServiceImpl algorithmModelService;
 
     /**
      * 查询算法模型配置列表
@@ -63,6 +64,11 @@ public class AlgorithmModelController extends BaseController {
         return CommonResult.success(algorithmModelService.selectById(id));
     }
 
+    @GetMapping(value = "algoList")
+    public CommonResult<List<AlgorithmModelVo>> getAlgorithmRelatedModelInfo(Long id) {
+        return CommonResult.success(algorithmModelService.getModelByAlgorithmId(id));
+    }
+
     /**
      * 新增算法模型配置
      */

+ 19 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/PublicController.java

@@ -1,6 +1,8 @@
 package com.taais.biz.controller;
 
+import com.alibaba.fastjson2.JSON;
 import com.taais.biz.constant.BizConstant;
+import com.taais.biz.domain.bo.AlgorithmModelBo;
 import com.taais.biz.domain.bo.TargetIdentificationSubtaskDetailsBo;
 import com.taais.biz.domain.bo.VideoStableStartResultBo;
 import com.taais.biz.domain.dto.TaskResultDTO;
@@ -8,6 +10,7 @@ import com.taais.biz.service.IAlgorithmBizProcessService;
 import com.taais.biz.service.IAlgorithmDataProcessService;
 import com.taais.biz.service.IAlgorithmTaskService;
 import com.taais.biz.service.IVideoStableService;
+import com.taais.biz.service.impl.AlgorithmModelServiceImpl;
 import com.taais.biz.service.impl.TargetIdentificationSubtaskDetailsServiceImpl;
 import com.taais.common.core.core.domain.CommonResult;
 import com.taais.common.log.annotation.Log;
@@ -23,6 +26,8 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
 
+import java.util.Arrays;
+
 /**
  * @author allen
  */
@@ -46,6 +51,9 @@ public class PublicController extends BaseController {
     @Resource
     private IAlgorithmTaskService algorithmTaskService;
 
+    @Resource
+    AlgorithmModelServiceImpl algorithmModelService;
+
 
 
     // todo: 2024080906
@@ -63,7 +71,18 @@ public class PublicController extends BaseController {
             TargetIdentificationSubtaskDetailsBo detailsBo = detailsService.getById(resultDTO.getBizId());
             detailsBo.setStatus(resultDTO.getStatus() != 200 ? BizConstant.TASK_STATUS_FAILED :
                 resultDTO.getMsg().contains("finish") ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_PROCESSING);
+            detailsBo.setRemarks(JSON.toJSONString(Arrays.asList(resultDTO.getMin(), resultDTO.getMax(), resultDTO.getAverage())));
             detailsService.update(detailsBo);
+
+            // 保存模型
+            if (BizConstant.TASK_STATUS_SUCCEED.equals(detailsBo.getStatus()) && detailsBo.getName().contains("训练")) {
+                Long algorithmId = detailsBo.getAlgorithmId();
+                AlgorithmModelBo bo = new AlgorithmModelBo();
+                bo.setAlgorithmId(algorithmId);
+                bo.setModelAddress("/profile/task" + detailsBo.getResultPath() + "weights/best.pt");
+                bo.setModelName(detailsBo.getName() + "_" + detailsBo.getCreateTime().toString());
+                algorithmModelService.insert(bo);
+            }
         } else if (BizConstant.TYPE_DATA_PROCESS.equals(bizType)) {
             errorMsg = dataProcessService.taskResult(resultDTO);
         } else {

+ 12 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/domain/dto/TaskResultDTO.java

@@ -8,4 +8,16 @@ public class TaskResultDTO {
     private String msg;
     private String bizType;
     private Long bizId;
+    private TimeResult max;
+    private TimeResult min;
+    private TimeResult average;
+
+    @Data
+    public class TimeResult {
+        String preprocess;
+        String inference;
+        String postprocess;
+    }
 }
+
+

+ 4 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/mapper/AlgorithmModelMapper.java

@@ -6,6 +6,8 @@ import org.apache.ibatis.annotations.Mapper;
 import com.taais.biz.domain.AlgorithmModel;
 import org.apache.ibatis.annotations.Param;
 
+import java.util.List;
+
 /**
  * 算法模型配置Mapper接口
  *
@@ -17,5 +19,6 @@ public interface AlgorithmModelMapper extends BaseMapper<AlgorithmModel> {
 
     String getModelNameBySubtaskId(@Param("subtaskId") Long subTaskId, @Param("algorithmId") Long algorithmId);
 
-    AlgorithmModelVo getModelByAlgorithmId(@Param("algorithmId") Long algorithmId);
+    List<AlgorithmModelVo> getModelByAlgorithmId(@Param("algorithmId") Long algorithmId);
+    AlgorithmModelVo getModelById(@Param("id") Long id);
 }

+ 5 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/AlgorithmModelServiceImpl.java

@@ -64,7 +64,7 @@ public class AlgorithmModelServiceImpl extends BaseServiceImpl<AlgorithmModelMap
      */
     @Override
     public AlgorithmModelVo selectById(Long id) {
-            return mapper.getModelByAlgorithmId(id);
+            return mapper.getModelById(id);
     }
 
     /**
@@ -92,6 +92,10 @@ public class AlgorithmModelServiceImpl extends BaseServiceImpl<AlgorithmModelMap
         return PageResult.build(page);
     }
 
+    public List<AlgorithmModelVo> getModelByAlgorithmId(Long algorithmId) {
+        return algorithmModelMapper.getModelByAlgorithmId(algorithmId);
+    }
+
     /**
      * 新增算法模型配置
      *

+ 36 - 3
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationSubtaskServiceImpl.java

@@ -5,18 +5,21 @@ import java.util.*;
 
 import cn.hutool.core.util.ObjectUtil;
 import cn.hutool.http.HttpUtil;
-import cn.hutool.json.JSON;
 import cn.hutool.json.JSONUtil;
+import com.alibaba.fastjson2.JSON;
 import com.alibaba.fastjson2.JSONArray;
+import com.alibaba.fastjson2.JSONObject;
 import com.google.gson.Gson;
 import com.google.gson.reflect.TypeToken;
 import com.mybatisflex.core.paginate.Page;
 import com.mybatisflex.core.query.QueryWrapper;
 import com.taais.biz.constant.BizConstant;
 import com.taais.biz.domain.*;
+import com.taais.biz.domain.bo.AlgorithmModelBo;
 import com.taais.biz.domain.bo.TargetIdentificationSubtaskDetailsBo;
 import com.taais.biz.domain.dto.AlgorithmConfigParamDto;
 import com.taais.biz.domain.dto.AlgorithmRequestDto;
+import com.taais.biz.domain.vo.AlgorithmModelVo;
 import com.taais.biz.domain.vo.AlgorithmTaskConfigurationVo;
 import com.taais.biz.domain.vo.TargetIdentificationSubtaskDetailsVo;
 import com.taais.common.core.constant.Constants;
@@ -59,6 +62,9 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
     @Resource
     private AlgorithmTaskConfigurationServiceImpl algorithmTaskConfigurationService;
 
+    @Resource
+    AlgorithmModelServiceImpl algorithmModelService;
+
     @Override
     public QueryWrapper query() {
         return super.query().from(TARGET_IDENTIFICATION_SUBTASK);
@@ -212,9 +218,31 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
         algorithmRequestDto.setLogPath(BizConstant.DOCKER_BASE_PATH + MINI_PREFIX + resultPath + "/log/log.log");
         algorithmRequestDto.setOtherParams(new HashMap<>());
 
+        boolean hasModelProperty = false;
+        String _modelId = null;
+        try {
+            JSONArray jsonArray = JSON.parseArray(parameters);
+            for (int i = 0; i < jsonArray.size(); i++) {
+                JSONObject object = jsonArray.getJSONObject(i);
+                if ("pretrained_model".equals(object.getString("agName"))) {
+                    hasModelProperty = true;
+                    _modelId = object.getString("modelId");
+                }
+                algorithmRequestDto.getOtherParams().put(object.getString("agName"), object.getString("defaultValue"));
+            }
+        } catch (Exception e) {
+            log.error(e.getMessage());
+        }
+
         String taskName = detail.getName();
         if (taskName.contains("训练")) {
             log.info("train");
+            if (hasModelProperty) {
+                AlgorithmModelVo bo = algorithmModelService.selectById(Long.valueOf(_modelId));
+                String path = bo.getModelAddress().replace("/profile", "/home/ObjectDetection_Web");
+                algorithmRequestDto.getOtherParams().put("pretrained", true);
+                algorithmRequestDto.getOtherParams().put("pretrained_model", path);
+            }
         } else if (taskName.contains("验证")) {
             String[] urls = url.split(";;;");
             url = urls[0];
@@ -228,14 +256,19 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
             return;
         }
         // send http
-        System.out.println("todo: " + algorithmRequestDto.toString());
+        System.out.println("todo request: " + algorithmRequestDto.toString());
 
         String res = HttpUtil.post(url, JSONUtil.toJsonStr(algorithmRequestDto));
         log.info("res is: {}", res);
         if (res != null) {
             log.info("version is : {}", detail.getVersion());
             try {
-                detail.setStatus(BizConstant.TASK_STATUS_SUCCEED);
+                JSONObject jsonObject = JSON.parseObject(res);
+                if (jsonObject.getInteger("status") == 200) {
+                    detail.setStatus(BizConstant.TASK_STATUS_SUCCEED);
+                } else {
+                    detail.setStatus(BizConstant.TASK_STATUS_FAILED);
+                }
             } catch (Exception e) {
                 log.error("http request error: {}", e.getMessage());
                 detail.setStatus(BizConstant.TASK_STATUS_FAILED);

+ 21 - 4
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationTaskServiceImpl.java

@@ -25,6 +25,7 @@ import com.taais.common.orm.core.page.PageQuery;
 import com.taais.common.core.core.page.PageResult;
 import com.taais.common.orm.core.service.impl.BaseServiceImpl;
 import jakarta.annotation.Resource;
+import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
@@ -44,6 +45,7 @@ import static com.taais.biz.domain.table.TargetIdentificationTaskTableDef.TARGET
  * 2024-08-17
  */
 @Service
+@Slf4j
 public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetIdentificationTaskMapper, TargetIdentificationTask> implements ITargetIdentificationTaskService {
     @Resource
     private TargetIdentificationTaskMapper targetIdentificationTaskMapper;
@@ -265,6 +267,8 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
 
     public static final String PATH_PREFIX = "/home/ObjectDetection_Web/task";
     public static final String WORK_DIR = "/home/ObjectDetection_Web";
+    //public static final String PATH_PREFIX = "ObjectDetection_Web/task";
+    //public static final String WORK_DIR = "ObjectDetection_Web";
     /**
      * 移动文件到对应文件夹
      * @param batch
@@ -297,6 +301,7 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
             for (DataVo dataVo : dataVoList) {
                 String[] strings = dataVo.getUrl().split("/profile");
                 String relativePath = WORK_DIR + strings[strings.length - 1];
+                relativePath = relativePath.replace("\\", "/").replace("//", "/");
                 File file = new File(relativePath);
                 System.out.println(file.getAbsolutePath());
                 if (file.exists()) {
@@ -312,6 +317,7 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
                 if (dataVo.getLabelurl() != null) {
                     strings = dataVo.getLabelurl().split("/profile");
                     relativePath = WORK_DIR + strings[strings.length - 1];
+                    relativePath = relativePath.replace("\\", "/").replace("//", "/");
                     file = new File(relativePath);
                     if (file.exists()) {
                         try {
@@ -356,13 +362,16 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
         for (String batchNum : batches) {
             List<DataVo> dataVoList = dataService.getDataByBatchNum(batchNum);
             for (DataVo dataVo : dataVoList) {
+                log.info("dataVo: " + dataVo.getUrl());
                 String[] strings = null;
                 String relativePath = null;
                 File file = null;
 
                 if (dataVo.getLabelurl() != null) {
                     strings = dataVo.getLabelurl().split("/profile");
-                    relativePath = "ObjectDetection_Web" + strings[strings.length - 1];
+                    relativePath = WORK_DIR + strings[strings.length - 1];
+                    relativePath = relativePath.replace("\\", "/").replace("//", "/");
+                    log.info("relativePath: " + relativePath);
                     file = new File(relativePath);
                     if (file.exists()) {
                         if (!moveLabelFile) {
@@ -373,7 +382,7 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
                             FileUtils.copyFile(file, dist);
                             System.out.println("file dist: " + dist.getAbsolutePath());
                         } catch (IOException e) {
-                            Log.debug("bug found");
+                            log.error("e happens: {}", e.getMessage());
                             continue;
                         }
                     }
@@ -382,11 +391,14 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
                 }
 
                 strings = dataVo.getUrl().split("/profile");
-                relativePath = "ObjectDetection_Web" + strings[strings.length - 1];
+                relativePath = WORK_DIR + strings[strings.length - 1];
+                relativePath = relativePath.replace("\\", "/").replace("//", "/");
+                log.info("relativePath url: " + relativePath);
                 file = new File(relativePath);
                 if (file.exists()) {
                     try {
                         File dist = new File(PATH_PREFIX + path + "/images/" + file.getName());
+                        log.info("get filename: {}", file.getName());
                         FileUtils.copyFile(file, dist);
                         System.out.println("file dist: " + dist.getAbsolutePath());
                     } catch (IOException e) {
@@ -412,7 +424,8 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
 
                 if (StringUtils.isNotEmpty(dataVo.getLabelurl())) {
                     strings = dataVo.getLabelurl().split("/profile");
-                    String relativePath = "ObjectDetection_Web" + strings[strings.length - 1];
+                    String relativePath = WORK_DIR + strings[strings.length - 1];
+                    relativePath = relativePath.replace("\\", "/").replace("//", "/");
                     File file = new File(relativePath);
                     if (file.exists()) {
                         return true;
@@ -450,6 +463,7 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
 
             for (int i = 0; i < testBatchNumList.size(); i++) {
                 String batchNum = testBatchNumList.get(i);
+                log.info("'batchNum' is: {}", batchNum);
                 TargetIdentificationSubtaskDetailsBo subtaskDetail = new TargetIdentificationSubtaskDetailsBo();
                 // 通过算法id获取算法配置
                 subtaskDetail.setSubtaskId(subtask.getId());
@@ -466,12 +480,15 @@ public class TargetIdentificationTaskServiceImpl extends BaseServiceImpl<TargetI
 
                 if (hasValidationSet(batchNum)) {
                     subtaskDetail.setName(algName + "_验证");
+                    subtaskDetail.setParameters(params.get(1));
                     copyFilesToPath(batchNum, subtaskPath, true);
                     subtaskDetail.setType(algorithmModelVo.getVerifyUrl() + ";;;" + records.get(algName));
                     subtaskDetailsService.insert(subtaskDetail);
+                    // reset to '测试'
                     subtaskPath = "/" + UUID.randomUUID().toString().replace("-", "_");
                     subtaskDetail.setPreprocessPath(subtaskPath);
                     subtaskDetail.setResultPath(subtaskPath + "/result");
+                    subtaskDetail.setParameters(params.get(2));
                 }
 
                 subtaskDetail.setName(algName + "_测试");

+ 3 - 0
taais-modules/taais-biz/src/main/resources/mapper/ag/AlgorithmModelMapper.xml

@@ -13,4 +13,7 @@
     <select id="getModelByAlgorithmId" resultType="com.taais.biz.domain.vo.AlgorithmModelVo">
         select * from algorithm_model am where am.algorithm_id = #{algorithmId}
     </select>
+    <select id="getModelById" resultType="com.taais.biz.domain.vo.AlgorithmModelVo">
+        select * from algorithm_model am where am.id = #{id}
+    </select>
 </mapper>