28968 6 hónapja
szülő
commit
85b9bd8e69

+ 63 - 39
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/DataAugmentationController.java

@@ -16,15 +16,15 @@ import java.util.stream.Stream;
 import com.alibaba.fastjson2.JSONArray;
 import com.alibaba.fastjson2.JSONObject;
 import com.mybatisflex.core.query.QueryWrapper;
+import com.taais.biz.domain.CommonAlgorithmConfig;
 import com.taais.biz.domain.DataAugmentation;
+import com.taais.biz.domain.bo.CommonAlgorithmConfigBo;
 import com.taais.biz.domain.bo.DataAugmentationBo;
 import com.taais.biz.domain.bo.DataAugmentationResultBo;
 import com.taais.biz.domain.bo.VideoStableStartResultBo;
 import com.taais.biz.domain.dto.Metric;
-import com.taais.biz.domain.vo.DataAugmentationVo;
-import com.taais.biz.domain.vo.ImageUrlPair;
-import com.taais.biz.domain.vo.TaskDictDataVo;
-import com.taais.biz.domain.vo.VideoUrl;
+import com.taais.biz.domain.vo.*;
+import com.taais.biz.service.ICommonAlgorithmConfigService;
 import com.taais.biz.service.IVideoStableService;
 import com.taais.common.core.config.TaaisConfig;
 import com.taais.common.core.service.OssService;
@@ -77,6 +77,9 @@ public class DataAugmentationController extends BaseController {
     private ISysDictDataService sysDictDataService;
     @Autowired
     private IVideoStableService videoStableService;
+
+    @Autowired
+    private ICommonAlgorithmConfigService commonAlgorithmConfigService;
     @GetMapping("/compare/num/{task_id}")
 
     public CommonResult getCompareNum(@PathVariable("task_id") Long taskId) {
@@ -207,27 +210,27 @@ public class DataAugmentationController extends BaseController {
                                 return path.getFileName().toString();
                             }).sorted().collect(Collectors.toList());
                         if ("目标毁伤评估".equals(taskType)) {
-                            QueryWrapper queryWrapper = new QueryWrapper();
-                            queryWrapper.eq("dict_type", "biz_data_augmentation");
-                            queryWrapper.eq("dict_label", "目标毁伤评估必须输出文件名称");
-                            List<SysDictData> list = sysDictDataService.list(queryWrapper);
-                            String[] split = list.get(0).getDictValue().split(",");
-                            if (list.isEmpty()) {
-                                origin.add("未在数据字典中定义目标毁伤评估必须输出文件名称");
-                                origin_list.add(origin);
-                                images.put("error", origin_list);
-                                return ResponseEntity.status(500).body(images);
-                            } else {
-
-                                for (String fileName: split) {
-                                    if (!outputFileList.contains(fileName)) {
-                                        stable.add("输出目录不存在结果文件:" + fileName);
-                                        stable_list.add(stable);
-                                        images.put("error", stable_list);
-                                        return ResponseEntity.status(500).body(images);
-                                    }
+//                            QueryWrapper queryWrapper = new QueryWrapper();
+//                            queryWrapper.eq("dict_type", "biz_data_augmentation");
+//                            queryWrapper.eq("dict_label", "目标毁伤评估必须输出文件名称");
+//                            List<SysDictData> list = sysDictDataService.list(queryWrapper);
+                            String[] split = new String[]{"diffIm.jpg","largest_salient_region.jpg","final_region.jpg"};
+//                            if (list.isEmpty()) {
+//                                origin.add("未在数据字典中定义目标毁伤评估必须输出文件名称");
+//                                origin_list.add(origin);
+//                                images.put("error", origin_list);
+//                                return ResponseEntity.status(500).body(images);
+//                            } else {
+
+                            for (String fileName: split) {
+                                if (!outputFileList.contains(fileName)) {
+                                    stable.add("输出目录不存在结果文件:" + fileName);
+                                    stable_list.add(stable);
+                                    images.put("error", stable_list);
+                                    return ResponseEntity.status(500).body(images);
                                 }
                             }
+
                             outputFileList.sort(Comparator.comparingInt(DataAugmentationController::extractNumber));
                             Map<String, Integer> indexMap = new HashMap<>();
                             for (int i = 0; i < split.length; i++) {
@@ -523,28 +526,49 @@ public class DataAugmentationController extends BaseController {
     }
 
     @SaCheckPermission("demo:dataAugmentation:query")
-    @GetMapping(value = "/getTaskDictData")
-    public CommonResult<List<TaskDictDataVo>> getTaskDictData() {
+    @PostMapping(value = "/getTaskDictData")
+    public CommonResult<List<TaskDictDataVo>> getTaskDictData(@RequestBody String module) {
+        module = module.substring(1, module.length() - 1);  //去除字符串两把的引号
         ArrayList<TaskDictDataVo> taskDictDataVos = new ArrayList<>();
+        //之前从数据字典中获取算法任务及其对应超参配置
+//        QueryWrapper queryWrapper = new QueryWrapper();
+//        queryWrapper.eq("dict_type", "biz_data_augmentation");
+//        queryWrapper.eq("dict_label", "任务类型");
+//        List<SysDictData> taskTypeList = sysDictDataService.list(queryWrapper);
+//        for (SysDictData sysDictData: taskTypeList) {
+//            String dictValue = sysDictData.getDictValue();
+//            queryWrapper.clear();
+//            queryWrapper.eq("dict_type", "biz_data_augmentation");
+//            queryWrapper.eq("dict_label", dictValue + "超参配置");
+//            List<SysDictData> hyparamList = sysDictDataService.list(queryWrapper);
+//            if (hyparamList.isEmpty()) {
+//                return CommonResult.fail("未在type为biz_data_augmentation的字典数据中设置" + dictValue + "的超参配置");
+//            }
+//            TaskDictDataVo taskDictDataVo = new TaskDictDataVo();
+//            taskDictDataVo.setTaskType(dictValue);
+//            taskDictDataVo.setHyperparameterConfiguration(hyparamList.get(0).getDictValue());
+//            taskDictDataVos.add(taskDictDataVo);
+//        }
+        //现在从新建的通用算法配置表中获取
+        //List<CommonAlgorithmConfig> list = commonAlgorithmConfigService.selectAll();
+//        System.out.println("model" + module);
         QueryWrapper queryWrapper = new QueryWrapper();
-        queryWrapper.eq("dict_type", "biz_data_augmentation");
-        queryWrapper.eq("dict_label", "任务类型");
-        List<SysDictData> taskTypeList = sysDictDataService.list(queryWrapper);
-        for (SysDictData sysDictData: taskTypeList) {
-            String dictValue = sysDictData.getDictValue();
-            queryWrapper.clear();
-            queryWrapper.eq("dict_type", "biz_data_augmentation");
-            queryWrapper.eq("dict_label", dictValue + "超参配置");
-            List<SysDictData> hyparamList = sysDictDataService.list(queryWrapper);
-            if (hyparamList.isEmpty()) {
-                return CommonResult.fail("未在type为biz_data_augmentation的字典数据中设置" + dictValue + "的超参配置");
-            }
+        queryWrapper.eq("module", module);
+        List<CommonAlgorithmConfig> list = commonAlgorithmConfigService.list(queryWrapper);
+        if (list.isEmpty()) {
+            return CommonResult.fail("通用算法配置表为空!");
+        }
+        for (CommonAlgorithmConfig commonAlgorithmConfig: list) {
+
             TaskDictDataVo taskDictDataVo = new TaskDictDataVo();
-            taskDictDataVo.setTaskType(dictValue);
-            taskDictDataVo.setHyperparameterConfiguration(hyparamList.get(0).getDictValue());
+            taskDictDataVo.setTaskType(commonAlgorithmConfig.getAlgorithmName());
+            taskDictDataVo.setHyperparameterConfiguration(commonAlgorithmConfig.getParameters());
             taskDictDataVos.add(taskDictDataVo);
+
+
         }
         return CommonResult.success(taskDictDataVos);
+
     }
     private static boolean hasSupportedExtension(String filePath, List<String> extensions) {
         for (String ext : extensions) {

+ 56 - 20
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataAugmentationServiceImpl.java

@@ -4,15 +4,15 @@ import cn.hutool.core.util.ObjectUtil;
 import com.mybatisflex.core.paginate.Page;
 import com.mybatisflex.core.query.QueryWrapper;
 import com.taais.biz.constant.BizConstant;
+import com.taais.biz.domain.CommonAlgorithmConfig;
 import com.taais.biz.domain.HttpResponseEntity;
 import com.taais.biz.domain.DataAugmentation;
 import com.taais.biz.domain.TransmissionObject;
-import com.taais.biz.domain.bo.DataAugmentationBo;
-import com.taais.biz.domain.bo.DataAugmentationResultBo;
-import com.taais.biz.domain.bo.DataAugmentationStartResultBo;
-import com.taais.biz.domain.bo.VideoStableStartResultBo;
+import com.taais.biz.domain.bo.*;
+import com.taais.biz.domain.vo.CommonAlgorithmConfigVo;
 import com.taais.biz.domain.vo.DataAugmentationVo;
 import com.taais.biz.mapper.DataAugmentationMapper;
+import com.taais.biz.service.ICommonAlgorithmConfigService;
 import com.taais.biz.service.IDataAugmentationService;
 import com.taais.biz.utils.ZipUtils;
 import com.taais.common.core.config.TaaisConfig;
@@ -76,6 +76,8 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
     @Autowired
     private ISysOssService ossService;
     @Autowired
+    private ICommonAlgorithmConfigService commonAlgorithmConfigService;
+    @Autowired
     private ISysDictDataService iSysDictDataService;
     @Resource
     private DataAugmentationMapper dataAugmentationMapper;
@@ -289,16 +291,32 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
             dataAugmentation.setInputPath(inputPath.toString());
             dataAugmentation.setAlgorithmPath(logPath.toString());
             dataAugmentation.setOutputPath(outputPath.toString());
-//            dataAugmentation.setStatus(BizConstant.ModelStatus.RUNNING);
-//            updateById(dataAugmentation);
-            SysDictDataBo sysDictDataBo = new SysDictDataBo();
-            sysDictDataBo.setDictLabel(dataAugmentation.getTaskType() + "开始url");
-            sysDictDataBo.setDictType("biz_data_augmentation");
-            List<SysDictDataVo> sysDictDataVos = iSysDictDataService.selectDictDataList(sysDictDataBo);
-            if (sysDictDataVos.size() == 0) {
-                return CommonResult.fail("未设置" + dataAugmentation.getTaskType() + "算法推理的url!请在数据字典中设置该算法的推理url!");
+            //由数据字典获取开始api
+//            SysDictDataBo sysDictDataBo = new SysDictDataBo();
+//            sysDictDataBo.setDictLabel(dataAugmentation.getTaskType() + "开始url");
+//            sysDictDataBo.setDictType("biz_data_augmentation");
+//            List<SysDictDataVo> sysDictDataVos = iSysDictDataService.selectDictDataList(sysDictDataBo);
+//            if (sysDictDataVos.size() == 0) {
+//                return CommonResult.fail("未设置" + dataAugmentation.getTaskType() + "算法推理的url!请在数据字典中设置该算法的推理url!");
+//            }
+            //由新建的通用算法配置表获取算法开始api
+//            CommonAlgorithmConfigBo commonAlgorithmConfigBo = new CommonAlgorithmConfigBo();
+//            commonAlgorithmConfigBo.setAlgorithmName(dataAugmentation.getTaskType());
+//
+//            List<CommonAlgorithmConfigVo> commonAlgorithmConfigVos = commonAlgorithmConfigService.selectList(commonAlgorithmConfigBo);
+//            if (commonAlgorithmConfigVos.size() == 0) {
+//                return CommonResult.fail("通用算法配置表中无" + dataAugmentation.getTaskType() + "算法配置");
+//            }
+//            String data_augmentation_start_url = commonAlgorithmConfigVos.get(0).getStartApi();
+            QueryWrapper queryWrapper = new QueryWrapper();
+            queryWrapper.eq("algorithm_name", dataAugmentation.getTaskType());
+            List<CommonAlgorithmConfig> list = commonAlgorithmConfigService.list(queryWrapper);
+            if (list.isEmpty()) {
+                return CommonResult.fail("通用算法配置表中无" + dataAugmentation.getTaskType() + "算法配置");
             }
-            String data_augmentation_start_url = sysDictDataVos.get(0).getDictValue();
+            CommonAlgorithmConfig commonAlgorithmConfig = list.get(0);
+
+            String data_augmentation_start_url = commonAlgorithmConfig.getStartApi();
             //设置传输对象
             TransmissionObject transmissionObject = new TransmissionObject();
             transmissionObject.setBizId(dataAugmentation.getId());
@@ -441,15 +459,33 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
 
     @Override
     public CommonResult stop(Long id) {
+        //从数据字典中获取停止api
         DataAugmentation dataAugmentation = getById(id);
-        SysDictDataBo sysDictDataBo = new SysDictDataBo();
-        sysDictDataBo.setDictLabel(dataAugmentation.getTaskType() + "停止url");
-        sysDictDataBo.setDictType("biz_data_augmentation");
-        List<SysDictDataVo> sysDictDataVos = iSysDictDataService.selectDictDataList(sysDictDataBo);
-        if (sysDictDataVos.size() == 0) {
-            return CommonResult.fail("未设置数据增强算法停止推理的url!请在数据字典中设置该算法的停止推理地址!");
+//        SysDictDataBo sysDictDataBo = new SysDictDataBo();
+//        sysDictDataBo.setDictLabel(dataAugmentation.getTaskType() + "停止url");
+//        sysDictDataBo.setDictType("biz_data_augmentation");
+//        List<SysDictDataVo> sysDictDataVos = iSysDictDataService.selectDictDataList(sysDictDataBo);
+//        if (sysDictDataVos.size() == 0) {
+//            return CommonResult.fail("未设置数据增强算法停止推理的url!请在数据字典中设置该算法的停止推理地址!");
+//        }
+//        String data_augmentation_stop_url = sysDictDataVos.get(0).getDictValue();
+        //从通用算法配置表中获取停止api
+//        CommonAlgorithmConfigBo commonAlgorithmConfigBo = new CommonAlgorithmConfigBo();
+//        commonAlgorithmConfigBo.setAlgorithmName(dataAugmentation.getTaskType());
+//
+//        List<CommonAlgorithmConfigVo> commonAlgorithmConfigVos = commonAlgorithmConfigService.selectList(commonAlgorithmConfigBo);
+        QueryWrapper queryWrapper = new QueryWrapper();
+        queryWrapper.eq("algorithm_name", dataAugmentation.getTaskType());
+        List<CommonAlgorithmConfig> list = commonAlgorithmConfigService.list(queryWrapper);
+        if (list.isEmpty()) {
+            return CommonResult.fail("通用算法配置表中无" + dataAugmentation.getTaskType() + "算法配置");
+        }
+        CommonAlgorithmConfig commonAlgorithmConfig = list.get(0);
+        if (commonAlgorithmConfig == null) {
+            return CommonResult.fail("通用算法配置表中无" + dataAugmentation.getTaskType() + "算法配置");
         }
-        String data_augmentation_stop_url = sysDictDataVos.get(0).getDictValue();
+        String data_augmentation_stop_url = commonAlgorithmConfig.getTerminateApi();
+        //设置传输对象
         TransmissionObject transmissionObject = new TransmissionObject();
         transmissionObject.setBizId(dataAugmentation.getId());
         transmissionObject.setBizType(dataAugmentation.getTaskType());